diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml index aef8006b4a84..4fb357fe2e7c 100644 --- a/.github/workflows/master.yml +++ b/.github/workflows/master.yml @@ -15,8 +15,10 @@ jobs: - name: Set up JDK uses: actions/setup-java@v4 with: - distribution: 'temurin' - java-version: '21' + distribution: temurin + java-version: 21 + cache: sbt + - uses: sbt/setup-sbt@v1 - name: Set up Ruby uses: ruby/setup-ruby@v1 with: diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index c0e7a55b8428..0efefc19a132 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -13,18 +13,13 @@ jobs: - name: Set up JDK uses: actions/setup-java@v4 with: - distribution: 'temurin' - java-version: '21' - - name: Install sbt - if: matrix.os == 'macos-latest' - run: brew install sbt + distribution: temurin + java-version: 21 + cache: sbt + - uses: sbt/setup-sbt@v1 - name: Install php if: matrix.os == 'macos-latest' run: brew install php - - name: Set up Ruby - uses: ruby/setup-ruby@v1 - with: - ruby-version: 2.7 - name: Set up Go uses: actions/setup-go@v5 with: @@ -36,8 +31,6 @@ jobs: with: development: true swift-version: "5.10" - - name: Install Bundler - run: gem install bundler -v 2.4.22 - name: Delete `.rustup` directory run: rm -rf /home/runner/.rustup # to save disk space if: runner.os == 'Linux' @@ -61,8 +54,10 @@ jobs: - name: Set up JDK uses: actions/setup-java@v4 with: - distribution: 'temurin' - java-version: '21' + distribution: temurin + java-version: 21 + cache: sbt + - uses: sbt/setup-sbt@v1 - uses: actions/cache@v4 with: path: | @@ -75,6 +70,8 @@ jobs: if: ${{ failure() }} - name: Validate CITATION.cff uses: dieghernan/cff-validator@v3 + with: + install-r: true test-scripts: runs-on: ubuntu-latest @@ -85,8 +82,10 @@ jobs: - name: Set up JDK uses: actions/setup-java@v4 with: - distribution: 'temurin' - java-version: '21' + distribution: temurin + java-version: 21 + cache: sbt + - uses: sbt/setup-sbt@v1 - uses: actions/cache@v4 with: path: | @@ -100,7 +99,14 @@ jobs: ./joern --src /tmp/foo --run scan ./joern-scan /tmp/foo ./joern-scan --dump - ./joern-slice data-flow -o target/slice + - name: Joern Slice Testing + run: | + mkdir /tmp/slice + ./joern-slice data-flow tests/code/javasrc/SliceTest.java -o /tmp/slice/dataflow-slice-javasrc.json + echo "checking that the script output contains the content we expect:" + ./joern --script "tests/test-dataflow-slice.sc" --param sliceFile=/tmp/slice/dataflow-slice-javasrc.json | grep 'List(boolean b, b, this, s, "MALICIOUS", s, new Foo("MALICIOUS"), s, s, "SAFE", s, b, this, this, b, s, System.out)' + - name: SARIF Export Testing + run: ./tests/finding-to-sarif-test.sh - run: | cd joern-cli/target/universal/stage ./schema-extender/test.sh diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3f544f609161..83b1b3bf6fd8 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -15,8 +15,10 @@ jobs: - name: Set up JDK uses: actions/setup-java@v4 with: - distribution: 'temurin' - java-version: '21' + distribution: temurin + java-version: 21 + cache: sbt + - uses: sbt/setup-sbt@v1 - run: sudo apt update && sudo apt install -y gnupg - run: echo $PGP_SECRET | base64 --decode | gpg --batch --import env: diff --git a/.github/workflows/upgrade-deps.yml b/.github/workflows/upgrade-deps.yml index fd10ff676fdb..2d3e887805c8 100644 --- a/.github/workflows/upgrade-deps.yml +++ b/.github/workflows/upgrade-deps.yml @@ -17,8 +17,10 @@ jobs: - name: Set up JDK uses: actions/setup-java@v4 with: - distribution: 'temurin' - java-version: '21' + distribution: temurin + java-version: 21 + cache: sbt + - uses: sbt/setup-sbt@v1 - uses: actions/cache@v4 with: path: | diff --git a/.gitignore b/.gitignore index 8f72cc50071b..e513037d1581 100644 --- a/.gitignore +++ b/.gitignore @@ -80,3 +80,5 @@ flake.lock **/.bsp +/joern-cli/frontends/c2cpg/eclipse-cdt/build/ +/joern-cli/frontends/c2cpg/eclipse-cdt/org.eclipse.cdt.core-*.jar diff --git a/Dockerfile b/Dockerfile index 9c82db69dc19..e37e5927744b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,11 +1,11 @@ -FROM alpine:3.17.3 +FROM alpine:latest # dependencies RUN apk update && apk upgrade && apk add --no-cache openjdk17-jdk python3 git curl gnupg bash nss ncurses php RUN ln -sf python3 /usr/bin/python # sbt -ENV SBT_VERSION 1.8.0 +ENV SBT_VERSION 1.10.3 ENV SBT_HOME /usr/local/sbt ENV PATH ${PATH}:${SBT_HOME}/bin RUN curl -sL "https://github.com/sbt/sbt/releases/download/v$SBT_VERSION/sbt-$SBT_VERSION.tgz" | gunzip | tar -x -C /usr/local diff --git a/README.md b/README.md index c12c8cab1544..77627b4192af 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ Specification: https://cpg.joern.io ## News / Changelog +- Joern v4.0.0 [migrates from overflowdb to flatgraph](changelog/4.0.0-flatgraph.md) - Joern v2.0.0 [upgrades from Scala2 to Scala3](changelog/2.0.0-scala3.md) - Joern v1.2.0 removes the `overflowdb.traversal.Traversal` class. This change is not completely backwards compatible. See [here](changelog/traversal_removal.md) for a detailed writeup. diff --git a/build.sbt b/build.sbt index 897fa0fb8d67..b4cc8ee3907b 100644 --- a/build.sbt +++ b/build.sbt @@ -1,8 +1,8 @@ name := "joern" ThisBuild / organization := "io.joern" -ThisBuild / scalaVersion := "3.4.2" +ThisBuild / scalaVersion := "3.5.2" -val cpgVersion = "1.6.16" +val cpgVersion = "1.7.26" lazy val joerncli = Projects.joerncli lazy val querydb = Projects.querydb @@ -45,7 +45,8 @@ ThisBuild / compile / javacOptions ++= Seq( ThisBuild / scalacOptions ++= Seq( "-deprecation", // Emit warning and location for usages of deprecated APIs. "--release", - "11" + "11", + "-Wshadow:type-parameter-shadow", ) lazy val createDistribution = taskKey[File]("Create a complete Joern distribution") @@ -73,7 +74,8 @@ Global / onChangedBuildSource := ReloadOnSourceChanges // publishing info for sonatype / maven central ThisBuild / publishTo := sonatypePublishToBundle.value -sonatypeCredentialHost := "s01.oss.sonatype.org" +ThisBuild / sonatypeCredentialHost := xerial.sbt.Sonatype.sonatypeCentralHost + ThisBuild / scmInfo := Some(ScmInfo(url("https://github.com/joernio/joern"), "scm:git@github.com:joernio/joern.git")) ThisBuild / homepage := Some(url("https://joern.io/")) ThisBuild / licenses := List("Apache-2.0" -> url("http://www.apache.org/licenses/LICENSE-2.0")) diff --git a/changelog/4.0.0-flatgraph.md b/changelog/4.0.0-flatgraph.md new file mode 100644 index 000000000000..2e3de7ed0de7 --- /dev/null +++ b/changelog/4.0.0-flatgraph.md @@ -0,0 +1,43 @@ +# 4.0.x: Migration to flatgraph + +Joern uses the domain-specific classes from codepropertygraph, which (up to joern 2.x) were generated by overflowdb (specifically https://github.com/ShiftLeftSecurity/overflowdb and https://github.com/ShiftLeftSecurity/overflowdb-codegen). +As of joern 4.0.x we replaced overflowdb with it's successor, [flatgraph](https://github.com/joernio/flatgraph). The most important PRs paving the way for flatgraph are https://github.com/ShiftLeftSecurity/codepropertygraph/pull/1769 and https://github.com/joernio/joern/pull/4630. + +### Why the change? +Most importantly, flatgraph brings us about 40% less memory usage as well as faster traversals. The reduced memory footprint is achieved by flatgraph's efficient columnar layout, essentially we hold everything in few (albeit very large) arrays. +The faster traversals account for about 40% performance improvement for many joern use-cases, e.g. running the default passes while importing a large cpg into joern. Some numbers for Linux 4, as an example for a very large codebase. Numbers are based on my workstation and just rough measurements. + +Linux 4.1.16, cpg created with c2cpg after `importCpg` into joern: +* 48M nodes with 630M properties (mostly `String` and `Integer`) +* 431M edges with 115M properties (all `String`) + +| | joern 2 (overflowdb) | joern 4 (flatgraph) | +| --------------------------------------------|----------------------|-------------------- | +| heap after import (after garbage collection)| 33g | 20g | +| minimum required heap (Xmx) for import | 80g | 30g | +| time for importCpg | 18 minutes | 11 minutes | +| file size on disk | 2600M | 400m | + +Linux 5 and 6 are considerably larger, so I wasn't able to import them into joern 2 on my workstation (which has 128G physical memory). With joern 4 it works just fine with `./joern -J-Xmx90g` for linux6 😀 + +Also worth noting: one of overflowdb's features was the overflowing-to-disk mechanism. While it sounds nice to be able to handle graphs larger than the available memory, in practice it was too slow to be useful, so we didn't reimplement it in flatgraph. + +### API changes / upgrade guide +We tried to minimise the joern-user-facing API changes, and depending on your usage you may not notice anything at all. That being said, if your code makes use of the `overflowdb` namespace then you will have to make some changes. In most cases, it's simply a namespace change to `flatgraph`. Since hopefully no joern user used the overflowdb api (with one exception listed below), I won't list the changes here, instead please look at the [joern migration PR](https://github.com/joernio/joern/pull/4630/files) and/or ask us on [discord](https://discord.com/channels/832209896089976854/842699104383533076). + +Most relevant changes: +1) `overflowdb.BatchedUpdate.applyDiff` -> `flatgraph.DiffGraphApplier.applyDiff` +1) `io.shiftleft.passes.IntervalKeyPool` -> `io.joern.x2cpg.utils.IntervalKeyPool` + +1) `StoredNode.propertyOption` now returns an `Option` rather than a `java.util.Optional` - the API is almost identical, and there's builtin conversions both ways (`.toScala|.toJava` via `import scala.jdk.OptionConverters.*`). + +1) the arrow syntax for quickly constructing graphs, e.g. `v0 --- "CFG" --> v1`, quite useful for testing, doesn't exist in flatgraph yet. You'll need to create a diffgraph instead. There's plenty of examples in the [joern migration PR](https://github.com/joernio/joern/pull/4630/files). + +1) Edges can only have zero or one properties. Since the codepropertygraph schema never defined more than one property per edge type, this should not affect you as a joern user, unless you've extended the cpg schema... + +### Credits and kudos +Flatgraph is based on [@bbrehm](https://github.com/bbrehm)'s great ideas for a memory efficient columnar layout on the jvm. He built a working prototype with very promising benchmarks that convinced us that the effort to migrate is worth-while, and that turned out to be true. + +### Why did we leave out version 3? +I'm glad you asked! Version 3 is typically a source for trouble, you know... just look at Gnome 3, Python 3 and many more. The only exception is Scala 3, of course - ymmv :) + diff --git a/console/src/main/scala/io/joern/console/BridgeBase.scala b/console/src/main/scala/io/joern/console/BridgeBase.scala index cf1071ccb7dd..183be94d90b1 100644 --- a/console/src/main/scala/io/joern/console/BridgeBase.scala +++ b/console/src/main/scala/io/joern/console/BridgeBase.scala @@ -2,10 +2,12 @@ package io.joern.console import better.files.* import io.shiftleft.codepropertygraph.generated.Languages +import io.shiftleft.semanticcpg.sarif.SarifConfig import org.apache.commons.text.StringEscapeUtils import replpp.scripting.ScriptRunner import java.nio.file.{Files, Path} +import scala.collection.mutable import scala.jdk.CollectionConverters.* import scala.util.Try @@ -13,7 +15,8 @@ case class Config( scriptFile: Option[Path] = None, command: Option[String] = None, params: Map[String, String] = Map.empty, - additionalImports: Seq[Path] = Nil, + predefFiles: Seq[Path] = Nil, + runBefore: Seq[String] = Nil, additionalClasspathEntries: Seq[String] = Seq.empty, addPlugin: Option[String] = None, rmPlugin: Option[String] = None, @@ -70,8 +73,15 @@ trait BridgeBase extends InteractiveShell with ScriptExecution with PluginHandli .valueName("script1.sc") .unbounded() .optional() - .action((x, c) => c.copy(additionalImports = c.additionalImports :+ x)) - .text("import (and run) additional script(s) on startup - may be passed multiple times") + .action((x, c) => c.copy(predefFiles = c.predefFiles :+ x)) + .text("given source files will be compiled and added to classpath - this may be passed multiple times") + + opt[String]("runBefore") + .valueName("'import Int.MaxValue'") + .unbounded() + .optional() + .action((x, c) => c.copy(runBefore = c.runBefore :+ x)) + .text("given code will be executed on startup - this may be passed multiple times") opt[String]("classpathEntry") .valueName("path/to/classpath") @@ -211,14 +221,21 @@ trait BridgeBase extends InteractiveShell with ScriptExecution with PluginHandli } } - protected def createPredefFile(additionalLines: Seq[String] = Nil): Path = { - val tmpFile = Files.createTempFile("joern-predef", "sc") - Files.write(tmpFile, (predefLines ++ additionalLines).asJava) - tmpFile.toAbsolutePath - } - /** code that is executed on startup */ - protected def predefLines: Seq[String] + protected def runBeforeCode: Seq[String] + + protected def buildRunBeforeCode(config: Config): Seq[String] = { + val builder = Seq.newBuilder[String] + builder ++= runBeforeCode + config.cpgToLoad.foreach { cpgFile => + builder += s"""importCpg("$cpgFile")""" + } + config.forInputPath.foreach { name => + builder += s"""openForInputPath("$name")""".stripMargin + } + builder ++= config.runBefore + builder.result() + } protected def greeting: String @@ -229,19 +246,10 @@ trait BridgeBase extends InteractiveShell with ScriptExecution with PluginHandli trait InteractiveShell { this: BridgeBase => protected def startInteractiveShell(config: Config) = { - val replConfig = config.cpgToLoad.map { cpgFile => - "importCpg(\"" + cpgFile + "\")" - } ++ config.forInputPath.map { name => - s""" - |openForInputPath(\"$name\") - |""".stripMargin - } - - val predefFile = createPredefFile(replConfig.toSeq) - replpp.InteractiveShell.run( replpp.Config( - predefFiles = predefFile +: config.additionalImports, + predefFiles = config.predefFiles, + runBefore = buildRunBeforeCode(config), nocolors = config.nocolors, verbose = config.verbose, classpathConfig = replpp.Config @@ -268,10 +276,10 @@ trait ScriptExecution { this: BridgeBase => if (!Files.exists(scriptFile)) { Try(throw new AssertionError(s"given script file `$scriptFile` does not exist")) } else { - val predefFile = createPredefFile(importCpgCode(config)) val scriptReturn = ScriptRunner.exec( replpp.Config( - predefFiles = predefFile +: config.additionalImports, + predefFiles = config.predefFiles, + runBefore = buildRunBeforeCode(config), scriptFile = Option(scriptFile), command = config.command, params = config.params, @@ -286,18 +294,6 @@ trait ScriptExecution { this: BridgeBase => scriptReturn } } - - /** For the given config, generate a list of commands to import the CPG - */ - private def importCpgCode(config: Config): List[String] = { - config.cpgToLoad.map { cpgFile => - "importCpg(\"" + cpgFile + "\")" - }.toList ++ config.forInputPath.map { name => - s""" - |openForInputPath(\"$name\") - |""".stripMargin - } - } } trait PluginHandling { this: BridgeBase => @@ -406,10 +402,9 @@ trait PluginHandling { this: BridgeBase => trait ServerHandling { this: BridgeBase => protected def startHttpServer(config: Config): Unit = { - val predefFile = createPredefFile(Nil) - val baseConfig = replpp.Config( - predefFiles = predefFile +: config.additionalImports, + predefFiles = config.predefFiles, + runBefore = buildRunBeforeCode(config), verbose = true, // always print what's happening - helps debugging classpathConfig = replpp.Config .ForClasspath(inheritClasspath = true, dependencies = config.dependencies, resolvers = config.resolvers) diff --git a/console/src/main/scala/io/joern/console/Commit.scala b/console/src/main/scala/io/joern/console/Commit.scala index 402b16d4c619..acd53ddbe24a 100644 --- a/console/src/main/scala/io/joern/console/Commit.scala +++ b/console/src/main/scala/io/joern/console/Commit.scala @@ -3,7 +3,7 @@ package io.joern.console import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.passes.CpgPass import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder object Commit { val overlayName: String = "commit" @@ -26,7 +26,7 @@ class Commit(opts: CommitOptions) extends LayerCreator { builder.absorb(opts.diffGraphBuilder) } } - runPass(pass, context) + pass.createAndApply() opts.diffGraphBuilder = Cpg.newDiffGraphBuilder } diff --git a/console/src/main/scala/io/joern/console/Console.scala b/console/src/main/scala/io/joern/console/Console.scala index 9f8ac55e2b17..63236e5de1e2 100644 --- a/console/src/main/scala/io/joern/console/Console.scala +++ b/console/src/main/scala/io/joern/console/Console.scala @@ -12,8 +12,8 @@ import io.shiftleft.codepropertygraph.cpgloading.CpgLoader import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.dotextension.ImageViewer import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext} -import overflowdb.traversal.help.Doc -import overflowdb.traversal.help.Table.AvailableWidthProvider +import io.shiftleft.codepropertygraph.generated.help.Doc +import flatgraph.help.Table.AvailableWidthProvider import scala.sys.process.Process import scala.util.control.NoStackTrace @@ -349,10 +349,14 @@ class Console[T <: Project](loader: WorkspaceLoader[T], baseDir: File = File.cur val cpgDestinationPath = cpgDestinationPathOpt.get - if (CpgLoader.isLegacyCpg(cpgFile)) { - report("You have provided a legacy proto CPG. Attempting conversion.") + val isProtoFormat = CpgLoader.isProtoFormat(cpgFile.path) + val isOverflowDbFormat = CpgLoader.isOverflowDbFormat(cpgFile.path) + if (isProtoFormat || isOverflowDbFormat) { + if (isProtoFormat) report("You have provided a legacy proto CPG. Attempting conversion.") + else if (isOverflowDbFormat) report("You have provided a legacy overflowdb CPG. Attempting conversion.") try { - CpgConverter.convertProtoCpgToOverflowDb(cpgFile.path.toString, cpgDestinationPath.toString) + val cpg = CpgLoader.load(cpgFile.path, cpgDestinationPath) + cpg.close() } catch { case exc: Exception => report("Error converting legacy CPG: " + exc.getMessage) diff --git a/console/src/main/scala/io/joern/console/ConsoleConfig.scala b/console/src/main/scala/io/joern/console/ConsoleConfig.scala index a5b4ea738f95..fc86ca54d857 100644 --- a/console/src/main/scala/io/joern/console/ConsoleConfig.scala +++ b/console/src/main/scala/io/joern/console/ConsoleConfig.scala @@ -1,6 +1,6 @@ package io.joern.console -import better.files._ +import better.files.* import scala.annotation.tailrec import scala.collection.mutable diff --git a/console/src/main/scala/io/joern/console/CpgConverter.scala b/console/src/main/scala/io/joern/console/CpgConverter.scala index 3019e65ab816..7ce22c6bb670 100644 --- a/console/src/main/scala/io/joern/console/CpgConverter.scala +++ b/console/src/main/scala/io/joern/console/CpgConverter.scala @@ -1,14 +1,18 @@ package io.joern.console -import io.shiftleft.codepropertygraph.cpgloading.{CpgLoader, CpgLoaderConfig} -import overflowdb.Config +import io.shiftleft.codepropertygraph.cpgloading.{CpgLoader, ProtoCpgLoader} + +import java.nio.file.Paths object CpgConverter { - def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = { - val odbConfig = Config.withDefaults.withStorageLocation(dstFilename) - val config = CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) - CpgLoader.load(srcFilename, config).close + def convertProtoCpgToFlatgraph(srcFilename: String, dstFilename: String): Unit = { + val cpg = ProtoCpgLoader.loadFromProtoZip(srcFilename, Option(Paths.get(dstFilename))) + cpg.close() } + @deprecated("method got renamed to `convertProtoCpgToFlatgraph, please use that instead", "joern v3") + def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = + convertProtoCpgToFlatgraph(srcFilename, dstFilename) + } diff --git a/console/src/main/scala/io/joern/console/Help.scala b/console/src/main/scala/io/joern/console/Help.scala index a1b3e017709b..f7cf3dab982f 100644 --- a/console/src/main/scala/io/joern/console/Help.scala +++ b/console/src/main/scala/io/joern/console/Help.scala @@ -1,19 +1,25 @@ package io.joern.console -import overflowdb.traversal.help.DocFinder.* -import overflowdb.traversal.help.Table.AvailableWidthProvider -import overflowdb.traversal.help.{DocFinder, Table} +import flatgraph.help.DocFinder.* +import flatgraph.help.Table.{AvailableWidthProvider, Row} +import flatgraph.help.{DocFinder, Table} object Help { + /** allows users to extend the help table with additional entries */ + val additionalHelpEntries = Seq.newBuilder[Tuple3[String, String, String]] + def overview(clazz: Class[?])(using AvailableWidthProvider): String = { val columnNames = List("command", "description", "example") - val rows = DocFinder - .findDocumentedMethodsOf(clazz) - .map { case StepDoc(_, funcName, doc) => - List(funcName, doc.info, doc.example) - } - .toList ++ List(runRow) + + val rows = Seq.newBuilder[Row] + rows += runRow + DocFinder.findDocumentedMethodsOf(clazz).foreach { case StepDoc(_, funcName, doc) => + rows += List(funcName, doc.info, doc.example) + } + additionalHelpEntries.result().foreach { case (a, b, c) => + rows += List(a, b, c) + } val header = formatNoQuotes(""" | @@ -27,7 +33,7 @@ object Help { | | |""".stripMargin) - header + "\n" + Table(columnNames, rows.sortBy(_.head)).render + header + "\n" + Table(columnNames, rows.result().sortBy(_.head)).render } def format(text: String): String = { @@ -45,8 +51,7 @@ object Help { List("run", "Run analyzer on active CPG", "run.securityprofile") // Since `run` is generated dynamically, it's not picked up when looking - // through methods via reflection, and therefore, we are adding - // it manually. + // through methods via reflection, and therefore, we are adding it manually. def runLongHelp: String = Help.format(""" | @@ -60,11 +65,10 @@ object Help { } .mkString("\n") - val overview = Help.overview(clazz) s""" | class Helper() { | def run: String = Help.runLongHelp - | override def toString: String = \"\"\"$overview\"\"\" + | override def toString: String = Help.overview(classOf[${clazz.getName}]) | | $membersCode | } diff --git a/console/src/main/scala/io/joern/console/PluginManager.scala b/console/src/main/scala/io/joern/console/PluginManager.scala index 0d049a1ab640..0666355bc429 100644 --- a/console/src/main/scala/io/joern/console/PluginManager.scala +++ b/console/src/main/scala/io/joern/console/PluginManager.scala @@ -1,5 +1,5 @@ package io.joern.console -import better.files.Dsl._ +import better.files.Dsl.* import better.files.File import better.files.File.apply diff --git a/console/src/main/scala/io/joern/console/Run.scala b/console/src/main/scala/io/joern/console/Run.scala index 5dfe92451f25..30ed0f72d505 100644 --- a/console/src/main/scala/io/joern/console/Run.scala +++ b/console/src/main/scala/io/joern/console/Run.scala @@ -6,7 +6,7 @@ import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext} import org.reflections8.Reflections import org.reflections8.util.{ClasspathHelper, ConfigurationBuilder} -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* object Run { @@ -22,7 +22,7 @@ object Run { query.store()(builder) } } - runPass(pass, context) + pass.createAndApply() } }) } @@ -64,7 +64,7 @@ object Run { | |val opts = new OptsDynamic() | - | import _root_.overflowdb.BatchedUpdate.DiffGraphBuilder + | import _root_.io.shiftleft.codepropertygraph.generated.DiffGraphBuilder | implicit def _diffGraph: DiffGraphBuilder = opts.commit.diffGraphBuilder | def diffGraph = _diffGraph |""".stripMargin @@ -75,7 +75,7 @@ object Run { val toStringCode = s""" - | import overflowdb.traversal.help.Table + | import flatgraph.help.Table | override def toString() : String = { | val columnNames = List("name", "description") | val rows = diff --git a/console/src/main/scala/io/joern/console/cpgcreation/CSharpSrcCpgGenerator.scala b/console/src/main/scala/io/joern/console/cpgcreation/CSharpSrcCpgGenerator.scala new file mode 100644 index 000000000000..4f41c7bb978a --- /dev/null +++ b/console/src/main/scala/io/joern/console/cpgcreation/CSharpSrcCpgGenerator.scala @@ -0,0 +1,20 @@ +package io.joern.console.cpgcreation + +import io.joern.console.FrontendConfig +import java.nio.file.Path +import scala.util.Try + +case class CSharpSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator { + private lazy val command: Path = + if (isWin) rootPath.resolve("csharpsrc2cpg.bat") else rootPath.resolve("csharpsrc2cpg") + private lazy val cmdLineArgs = config.cmdLineParams.toSeq + + override def generate(inputPath: String, outputPath: String = "cpg.bin"): Try[String] = { + val arguments = cmdLineArgs ++ Seq(inputPath, "--output", outputPath) + runShellCommand(command.toString, arguments).map(_ => outputPath) + } + + override def isAvailable: Boolean = command.toFile.exists + + override def isJvmBased: Boolean = true +} diff --git a/console/src/main/scala/io/joern/console/cpgcreation/CpgGenerator.scala b/console/src/main/scala/io/joern/console/cpgcreation/CpgGenerator.scala index e429124dc67a..6a816135227f 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/CpgGenerator.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/CpgGenerator.scala @@ -3,7 +3,7 @@ package io.joern.console.cpgcreation import better.files.File import io.shiftleft.codepropertygraph.generated.Cpg -import scala.sys.process._ +import scala.sys.process.* import scala.util.Try /** A CpgGenerator generates Code Property Graphs from code. Each supported language implements a Generator, e.g., diff --git a/console/src/main/scala/io/joern/console/cpgcreation/CpgGeneratorFactory.scala b/console/src/main/scala/io/joern/console/cpgcreation/CpgGeneratorFactory.scala index 710ba95363dc..37362dfad732 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/CpgGeneratorFactory.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/CpgGeneratorFactory.scala @@ -1,11 +1,10 @@ package io.joern.console.cpgcreation -import better.files.Dsl._ +import better.files.Dsl.* import better.files.File -import io.shiftleft.codepropertygraph.cpgloading.{CpgLoader, CpgLoaderConfig} +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader import io.shiftleft.codepropertygraph.generated.Languages -import io.joern.console.ConsoleConfig -import overflowdb.Config +import io.joern.console.{ConsoleConfig, CpgConverter} import java.nio.file.Path import scala.util.Try @@ -60,12 +59,12 @@ class CpgGeneratorFactory(config: ConsoleConfig) { generator.generate(inputPath, outputPath).map(File(_)) outputFileOpt.map { outFile => val parentPath = outFile.parent.path.toAbsolutePath - if (isZipFile(outFile)) { + if (CpgLoader.isProtoFormat(outFile.path)) { report("Creating database from bin.zip") val srcFilename = outFile.path.toAbsolutePath.toString val dstFilename = parentPath.resolve("cpg.bin").toAbsolutePath.toString // MemoryHelper.hintForInsufficientMemory(srcFilename).map(report) - convertProtoCpgToOverflowDb(srcFilename, dstFilename) + convertProtoCpgToFlatgraph(srcFilename, dstFilename) } else { report("moving cpg.bin.zip to cpg.bin because it is already a database file") val srcPath = parentPath.resolve("cpg.bin.zip") @@ -77,18 +76,13 @@ class CpgGeneratorFactory(config: ConsoleConfig) { } } - def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = { - val odbConfig = Config.withDefaults.withStorageLocation(dstFilename) - val config = CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) - CpgLoader.load(srcFilename, config).close - File(srcFilename).delete() - } + @deprecated("method got renamed to `convertProtoCpgToFlatgraph, please use that instead", "joern v3") + def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = + convertProtoCpgToFlatgraph(srcFilename, dstFilename) - def isZipFile(file: File): Boolean = { - val bytes = file.bytes - Try { - bytes.next() == 'P' && bytes.next() == 'K' - }.getOrElse(false) + def convertProtoCpgToFlatgraph(srcFilename: String, dstFilename: String): Unit = { + CpgConverter.convertProtoCpgToFlatgraph(srcFilename, dstFilename) + File(srcFilename).delete() } private def report(str: String): Unit = System.err.println(str) diff --git a/console/src/main/scala/io/joern/console/cpgcreation/ImportCode.scala b/console/src/main/scala/io/joern/console/cpgcreation/ImportCode.scala index 1f683d49a741..b5340e2300d8 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/ImportCode.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/ImportCode.scala @@ -5,8 +5,8 @@ import io.joern.console.workspacehandling.Project import io.joern.console.{ConsoleException, FrontendConfig, Reporting} import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.Languages -import overflowdb.traversal.help.Table -import overflowdb.traversal.help.Table.AvailableWidthProvider +import flatgraph.help.Table +import flatgraph.help.Table.AvailableWidthProvider import java.nio.file.Path import scala.util.{Failure, Success, Try} @@ -52,7 +52,7 @@ class ImportCode[T <: Project](console: io.joern.console.Console[T])(implicit new BinaryFrontend("jvm", Languages.JAVA, "Java/Dalvik Bytecode Frontend (based on SOOT's jimple)") def ghidra: Frontend = new BinaryFrontend("ghidra", Languages.GHIDRA, "ghidra reverse engineering frontend") def kotlin: SourceBasedFrontend = - new SourceBasedFrontend("kotlin", Languages.KOTLIN, "Kotlin Source Frontend", "kotlin") + new SourceBasedFrontend("kotlin", Languages.KOTLIN, "Kotlin Source Frontend", "kt") def python: SourceBasedFrontend = new SourceBasedFrontend("python", Languages.PYTHONSRC, "Python Source Frontend", "py") def golang: SourceBasedFrontend = new SourceBasedFrontend("golang", Languages.GOLANG, "Golang Source Frontend", "go") @@ -62,11 +62,12 @@ class ImportCode[T <: Project](console: io.joern.console.Console[T])(implicit new JsFrontend("jssrc", Languages.JSSRC, "Javascript/Typescript Source Frontend based on astgen", "js") def swiftsrc: SourceBasedFrontend = new SwiftSrcFrontend("swiftsrc", Languages.SWIFTSRC, "Swift Source Frontend based on swiftastgen", "swift") - def csharp: Frontend = new BinaryFrontend("csharp", Languages.CSHARP, "C# Source Frontend (Roslyn)") + def csharp: Frontend = new BinaryFrontend("csharp", Languages.CSHARP, "C# Source Frontend (Roslyn)") + def csharpsrc: SourceBasedFrontend = + new SourceBasedFrontend("csharpsrc", Languages.CSHARPSRC, "C# Source Frontend based on DotNetAstGen", "cs") def llvm: Frontend = new BinaryFrontend("llvm", Languages.LLVM, "LLVM Bitcode Frontend") def php: SourceBasedFrontend = new SourceBasedFrontend("php", Languages.PHP, "PHP source frontend", "php") - def ruby: SourceBasedFrontend = new RubyFrontend("Ruby source frontend", false) - def rubyDeprecated: SourceBasedFrontend = new RubyFrontend("Ruby source deprecated frontend", true) + def ruby: SourceBasedFrontend = SourceBasedFrontend("ruby", Languages.RUBYSRC, "Ruby source frontend", "rb") private def allFrontends: List[Frontend] = List( @@ -85,7 +86,7 @@ class ImportCode[T <: Project](console: io.joern.console.Console[T])(implicit python, csharp, ruby, - rubyDeprecated + csharpsrc ) // this is only abstract to force people adding frontends to make a decision whether the frontend consumes binaries or source @@ -140,30 +141,6 @@ class ImportCode[T <: Project](console: io.joern.console.Console[T])(implicit } } - /** Only a wrapper so as to more easily pick the deprecated variant without having to provide the - * `--useDeprecatedFrontend` flag each time. - * - * @param useDeprecatedFrontend - * If set, will invoke the frontend with the `--useDeprecatedFrontend` flag - */ - private class RubyFrontend(description: String, useDeprecatedFrontend: Boolean = false) - extends SourceBasedFrontend("ruby", Languages.RUBYSRC, description, "rb") { - private val deprecatedFlag = "--useDeprecatedFrontend" - - private def addDeprecatedFlagIfNeeded(args: List[String]): List[String] = { - Option.when(useDeprecatedFrontend && !args.contains(deprecatedFlag))(deprecatedFlag).toList ++ args - } - - override def cpgGeneratorForLanguage( - language: String, - config: FrontendConfig, - rootPath: Path, - args: List[String] - ): Option[CpgGenerator] = { - super.cpgGeneratorForLanguage(language, config, rootPath, addDeprecatedFlagIfNeeded(args)) - } - } - class SwiftSrcFrontend(name: String, language: String, description: String, extension: String) extends SourceBasedFrontend(name, language, description, extension) { override def apply(inputPath: String, projectName: String, args: List[String]): Cpg = { diff --git a/console/src/main/scala/io/joern/console/cpgcreation/JavaCpgGenerator.scala b/console/src/main/scala/io/joern/console/cpgcreation/JavaCpgGenerator.scala index 1d9277902471..307a4d1793b7 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/JavaCpgGenerator.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/JavaCpgGenerator.scala @@ -3,7 +3,7 @@ package io.joern.console.cpgcreation import io.joern.console.FrontendConfig import java.nio.file.Path -import scala.sys.process._ +import scala.sys.process.* import scala.util.{Failure, Try} /** Language frontend for Java archives (JAR files). Translates Java archives into code property graphs. diff --git a/console/src/main/scala/io/joern/console/cpgcreation/JavaSrcCpgGenerator.scala b/console/src/main/scala/io/joern/console/cpgcreation/JavaSrcCpgGenerator.scala index 09a5fb003cb7..769a6e15cfd7 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/JavaSrcCpgGenerator.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/JavaSrcCpgGenerator.scala @@ -6,22 +6,20 @@ import io.joern.x2cpg.passes.frontend.XTypeRecoveryConfig import io.shiftleft.codepropertygraph.generated.Cpg import java.nio.file.Path -import scala.compiletime.uninitialized import scala.util.Try /** Source-based front-end for Java */ case class JavaSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator { private lazy val command: Path = if (isWin) rootPath.resolve("javasrc2cpg.bat") else rootPath.resolve("javasrc2cpg") - private var enableTypeRecovery = false - private var typeRecoveryConfig: XTypeRecoveryConfig = uninitialized + private lazy val cmdLineArgs = config.cmdLineParams.toSeq + private lazy val enableTypeRecovery = cmdLineArgs.exists(_ == s"--${javasrc2cpg.ParameterNames.EnableTypeRecovery}") + private lazy val typeRecoveryConfig = XTypeRecoveryConfig.parse(cmdLineArgs) /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was generated. */ override def generate(inputPath: String, outputPath: String = "cpg.bin"): Try[String] = { - val arguments = config.cmdLineParams.toSeq ++ Seq(inputPath, "--output", outputPath) - enableTypeRecovery = arguments.exists(_ == s"--${javasrc2cpg.ParameterNames.EnableTypeRecovery}") - if (enableTypeRecovery) typeRecoveryConfig = XTypeRecoveryConfig.parse(arguments) + val arguments = cmdLineArgs ++ Seq(inputPath, "--output", outputPath) runShellCommand(command.toString, arguments).map(_ => outputPath) } diff --git a/console/src/main/scala/io/joern/console/cpgcreation/JsSrcCpgGenerator.scala b/console/src/main/scala/io/joern/console/cpgcreation/JsSrcCpgGenerator.scala index 7d2ce5c5eb93..0ec4347ba392 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/JsSrcCpgGenerator.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/JsSrcCpgGenerator.scala @@ -7,18 +7,15 @@ import io.joern.x2cpg.passes.frontend.XTypeRecoveryConfig import io.shiftleft.codepropertygraph.generated.Cpg import java.nio.file.Path -import scala.compiletime.uninitialized import scala.util.Try case class JsSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator { private lazy val command: Path = if (isWin) rootPath.resolve("jssrc2cpg.bat") else rootPath.resolve("jssrc2cpg.sh") - private var typeRecoveryConfig: XTypeRecoveryConfig = uninitialized + private lazy val typeRecoveryConfig = XTypeRecoveryConfig.parse(config.cmdLineParams.toSeq) /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was generated. */ override def generate(inputPath: String, outputPath: String = "cpg.bin.zip"): Try[String] = { - typeRecoveryConfig = XTypeRecoveryConfig.parse(config.cmdLineParams.toSeq) - if (File(inputPath).isDirectory) { invoke(inputPath, outputPath) } else { diff --git a/console/src/main/scala/io/joern/console/cpgcreation/PhpCpgGenerator.scala b/console/src/main/scala/io/joern/console/cpgcreation/PhpCpgGenerator.scala index 765d9a824c29..a570751769e0 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/PhpCpgGenerator.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/PhpCpgGenerator.scala @@ -7,24 +7,23 @@ import io.shiftleft.codepropertygraph.generated.Cpg import scopt.OParser import java.nio.file.Path -import scala.compiletime.uninitialized import scala.util.Try case class PhpCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator { - private lazy val command: Path = if (isWin) rootPath.resolve("php2cpg.bat") else rootPath.resolve("php2cpg") - private var typeRecoveryConfig: XTypeRecoveryConfig = uninitialized - private var setKnownTypesConfig: XTypeStubsParserConfig = uninitialized - - override def generate(inputPath: String, outputPath: String): Try[String] = { - val cmdLineArgs = config.cmdLineParams.toSeq - typeRecoveryConfig = XTypeRecoveryConfig.parse(cmdLineArgs) - setKnownTypesConfig = OParser + private lazy val command: Path = if (isWin) rootPath.resolve("php2cpg.bat") else rootPath.resolve("php2cpg") + private lazy val cmdLineArgs = config.cmdLineParams.toSeq + private lazy val typeRecoveryConfig = XTypeRecoveryConfig.parse(cmdLineArgs) + private lazy val setKnownTypesConfig: XTypeStubsParserConfig = { + OParser .parse(XTypeStubsParser.parserOptions2, cmdLineArgs, XTypeStubsParserConfig()) .getOrElse( throw new RuntimeException( s"unable to parse XTypeStubsParserConfig from commandline arguments ${cmdLineArgs.mkString(" ")}" ) ) + } + + override def generate(inputPath: String, outputPath: String): Try[String] = { val arguments = List(inputPath) ++ Seq("-o", outputPath) ++ config.cmdLineParams runShellCommand(command.toString, arguments).map(_ => outputPath) } diff --git a/console/src/main/scala/io/joern/console/cpgcreation/PythonSrcCpgGenerator.scala b/console/src/main/scala/io/joern/console/cpgcreation/PythonSrcCpgGenerator.scala index bf73f66be772..211b5226f6c9 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/PythonSrcCpgGenerator.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/PythonSrcCpgGenerator.scala @@ -1,26 +1,21 @@ package io.joern.console.cpgcreation import io.joern.console.FrontendConfig -import io.joern.x2cpg.X2Cpg import io.joern.x2cpg.frontendspecific.pysrc2cpg import io.joern.x2cpg.frontendspecific.pysrc2cpg.* -import io.joern.x2cpg.passes.base.AstLinkerPass -import io.joern.x2cpg.passes.callgraph.NaiveCallLinker import io.joern.x2cpg.passes.frontend.XTypeRecoveryConfig import io.shiftleft.codepropertygraph.generated.Cpg import java.nio.file.Path import scala.util.Try -import scala.compiletime.uninitialized case class PythonSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator { - private lazy val command: Path = if (isWin) rootPath.resolve("pysrc2cpg.bat") else rootPath.resolve("pysrc2cpg") - private var typeRecoveryConfig: XTypeRecoveryConfig = uninitialized + private lazy val command: Path = if (isWin) rootPath.resolve("pysrc2cpg.bat") else rootPath.resolve("pysrc2cpg") + private lazy val typeRecoveryConfig = XTypeRecoveryConfig.parse(config.cmdLineParams.toSeq) /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was generated. */ override def generate(inputPath: String, outputPath: String = "cpg.bin.zip"): Try[String] = { - typeRecoveryConfig = XTypeRecoveryConfig.parse(config.cmdLineParams.toSeq) val arguments = Seq(inputPath, "-o", outputPath) ++ config.cmdLineParams runShellCommand(command.toString, arguments).map(_ => outputPath) } @@ -30,7 +25,6 @@ case class PythonSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends override def applyPostProcessingPasses(cpg: Cpg): Cpg = { pysrc2cpg.postProcessingPasses(cpg, typeRecoveryConfig).foreach(_.createAndApply()) - cpg } diff --git a/console/src/main/scala/io/joern/console/cpgcreation/RubyCpgGenerator.scala b/console/src/main/scala/io/joern/console/cpgcreation/RubyCpgGenerator.scala index 72fd425b9fb7..82f945090e89 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/RubyCpgGenerator.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/RubyCpgGenerator.scala @@ -21,9 +21,7 @@ case class RubyCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgG override def isJvmBased = true override def applyPostProcessingPasses(cpg: Cpg): Cpg = { - // TODO: here we need a Ruby-specific `Config`, which we shall build from the existing `FrontendConfig`. We only - // care for `--useDeprecatedFrontend` though, for now. Nevertheless, there should be a better way of handling this. - val rubyConfig = Config().withUseDeprecatedFrontend(config.cmdLineParams.exists(_ == "--useDeprecatedFrontend")) + val rubyConfig = Config() RubySrc2Cpg.postProcessingPasses(cpg, rubyConfig).foreach(_.createAndApply()) cpg } diff --git a/console/src/main/scala/io/joern/console/cpgcreation/SwiftSrcCpgGenerator.scala b/console/src/main/scala/io/joern/console/cpgcreation/SwiftSrcCpgGenerator.scala index ed0eed6f6f48..504750f4ef62 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/SwiftSrcCpgGenerator.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/SwiftSrcCpgGenerator.scala @@ -7,13 +7,13 @@ import io.joern.x2cpg.passes.frontend.XTypeRecoveryConfig import io.shiftleft.codepropertygraph.generated.Cpg import java.nio.file.Path -import scala.compiletime.uninitialized import scala.util.Try case class SwiftSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator { private lazy val command: Path = if (isWin) rootPath.resolve("swiftsrc2cpg.bat") else rootPath.resolve("swiftsrc2cpg.sh") - private var typeRecoveryConfig: XTypeRecoveryConfig = uninitialized + + private lazy val typeRecoveryConfig = XTypeRecoveryConfig.parse(config.cmdLineParams.toSeq) /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was generated. */ @@ -29,7 +29,6 @@ case class SwiftSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends private def invoke(inputPath: String, outputPath: String): Try[String] = { val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams - typeRecoveryConfig = XTypeRecoveryConfig.parse(config.cmdLineParams.toSeq) runShellCommand(command.toString, arguments).map(_ => outputPath) } diff --git a/console/src/main/scala/io/joern/console/cpgcreation/package.scala b/console/src/main/scala/io/joern/console/cpgcreation/package.scala index 43712f5df7b8..dd8bbb3c1ccf 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/package.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/package.scala @@ -19,12 +19,13 @@ package object cpgcreation { ): Option[CpgGenerator] = { lazy val conf = config.withArgs(args) language match { - case Languages.CSHARP | Languages.CSHARPSRC => Some(CSharpCpgGenerator(conf, rootPath)) - case Languages.C | Languages.NEWC => Some(CCpgGenerator(conf, rootPath)) - case Languages.LLVM => Some(LlvmCpgGenerator(conf, rootPath)) - case Languages.GOLANG => Some(GoCpgGenerator(conf, rootPath)) - case Languages.JAVA => Some(JavaCpgGenerator(conf, rootPath)) - case Languages.JAVASRC => Some(JavaSrcCpgGenerator(conf, rootPath)) + case Languages.CSHARP => Some(CSharpCpgGenerator(conf, rootPath)) + case Languages.CSHARPSRC => Some(CSharpSrcCpgGenerator(conf, rootPath)) + case Languages.C | Languages.NEWC => Some(CCpgGenerator(conf, rootPath)) + case Languages.LLVM => Some(LlvmCpgGenerator(conf, rootPath)) + case Languages.GOLANG => Some(GoCpgGenerator(conf, rootPath)) + case Languages.JAVA => Some(JavaCpgGenerator(conf, rootPath)) + case Languages.JAVASRC => Some(JavaSrcCpgGenerator(conf, rootPath)) case Languages.JSSRC | Languages.JAVASCRIPT => val jssrc = JsSrcCpgGenerator(conf, rootPath) if (jssrc.isAvailable) Some(jssrc) diff --git a/console/src/main/scala/io/joern/console/package.scala b/console/src/main/scala/io/joern/console/package.scala index a29435afba29..958f26202bde 100644 --- a/console/src/main/scala/io/joern/console/package.scala +++ b/console/src/main/scala/io/joern/console/package.scala @@ -1,6 +1,6 @@ package io.joern -import overflowdb.traversal.help.Table.AvailableWidthProvider +import flatgraph.help.Table.AvailableWidthProvider import replpp.Operators.* import replpp.Colors diff --git a/console/src/main/scala/io/joern/console/scan/package.scala b/console/src/main/scala/io/joern/console/scan/package.scala index 41e01d2e32e4..31230076bb8d 100644 --- a/console/src/main/scala/io/joern/console/scan/package.scala +++ b/console/src/main/scala/io/joern/console/scan/package.scala @@ -11,11 +11,6 @@ package object scan { private val logger: Logger = LoggerFactory.getLogger(this.getClass) - implicit class ScannerStarters(val cpg: Cpg) extends AnyVal { - def finding: Iterator[Finding] = - overflowdb.traversal.InitialTraversal.from[Finding](cpg.graph, NodeTypes.FINDING) - } - implicit class QueryWrapper(q: Query) { /** Obtain list of findings by running query on CPG diff --git a/console/src/main/scala/io/joern/console/workspacehandling/Project.scala b/console/src/main/scala/io/joern/console/workspacehandling/Project.scala index 8de086cd693b..68d2ddc8780e 100644 --- a/console/src/main/scala/io/joern/console/workspacehandling/Project.scala +++ b/console/src/main/scala/io/joern/console/workspacehandling/Project.scala @@ -1,6 +1,6 @@ package io.joern.console.workspacehandling -import better.files.Dsl._ +import better.files.Dsl.* import better.files.File import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.semanticcpg.Overlays diff --git a/console/src/main/scala/io/joern/console/workspacehandling/Workspace.scala b/console/src/main/scala/io/joern/console/workspacehandling/Workspace.scala index 0d3a068d4bc2..82caaff7897b 100644 --- a/console/src/main/scala/io/joern/console/workspacehandling/Workspace.scala +++ b/console/src/main/scala/io/joern/console/workspacehandling/Workspace.scala @@ -1,8 +1,8 @@ package io.joern.console.workspacehandling import io.joern.console.defaultAvailableWidthProvider -import overflowdb.traversal.help.Table -import overflowdb.traversal.help.Table.AvailableWidthProvider +import flatgraph.help.Table +import flatgraph.help.Table.AvailableWidthProvider import scala.collection.mutable.ListBuffer diff --git a/console/src/main/scala/io/joern/console/workspacehandling/WorkspaceLoader.scala b/console/src/main/scala/io/joern/console/workspacehandling/WorkspaceLoader.scala index c0297fb43551..4d7d943db5e1 100644 --- a/console/src/main/scala/io/joern/console/workspacehandling/WorkspaceLoader.scala +++ b/console/src/main/scala/io/joern/console/workspacehandling/WorkspaceLoader.scala @@ -2,7 +2,7 @@ package io.joern.console.workspacehandling import better.files.Dsl.mkdirs import better.files.File -import overflowdb.traversal.help.Table.AvailableWidthProvider +import flatgraph.help.Table.AvailableWidthProvider import java.nio.file.Path import scala.collection.mutable.ListBuffer diff --git a/console/src/main/scala/io/joern/console/workspacehandling/WorkspaceManager.scala b/console/src/main/scala/io/joern/console/workspacehandling/WorkspaceManager.scala index 723a058617dc..3beb2345060a 100644 --- a/console/src/main/scala/io/joern/console/workspacehandling/WorkspaceManager.scala +++ b/console/src/main/scala/io/joern/console/workspacehandling/WorkspaceManager.scala @@ -1,18 +1,17 @@ package io.joern.console.workspacehandling -import better.files.Dsl._ -import better.files._ +import better.files.Dsl.* +import better.files.* import io.joern.console import io.joern.console.defaultAvailableWidthProvider import io.joern.console.Reporting import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.cpgloading.{CpgLoader, CpgLoaderConfig} +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader import org.json4s.DefaultFormats -import org.json4s.native.Serialization.{write => jsonWrite} -import overflowdb.Config +import org.json4s.native.Serialization.write as jsonWrite import java.net.URLEncoder -import java.nio.file.Path +import java.nio.file.{Path, Paths} import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} @@ -305,17 +304,12 @@ class WorkspaceManager[ProjectType <: Project](path: String, loader: WorkspaceLo private def loadCpgRaw(cpgFilename: String): Option[Cpg] = { Try { - val odbConfig = Config.withDefaults.withStorageLocation(cpgFilename) - val config = - CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) - val newCpg = CpgLoader.loadFromOverflowDb(config) - CpgLoader.createIndexes(newCpg) - newCpg + CpgLoader.load(cpgFilename) } match { case Success(v) => Some(v) case Failure(ex) => System.err.println("Error loading CPG") - System.err.println(ex) + ex.printStackTrace() None } } diff --git a/console/src/test/scala/io/joern/console/ConsoleTests.scala b/console/src/test/scala/io/joern/console/ConsoleTests.scala index 1114558d56fb..728a83d55331 100644 --- a/console/src/test/scala/io/joern/console/ConsoleTests.scala +++ b/console/src/test/scala/io/joern/console/ConsoleTests.scala @@ -1,11 +1,11 @@ package io.joern.console -import better.files.Dsl._ -import better.files._ -import io.joern.console.testing._ +import better.files.Dsl.* +import better.files.* +import io.joern.console.testing.* import io.joern.x2cpg.X2Cpg.defaultOverlayCreators import io.joern.x2cpg.layers.{Base, CallGraph, ControlFlow, TypeRelations} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext} import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec @@ -428,7 +428,7 @@ class ConsoleTests extends AnyWordSpec with Matchers { "cpg" should { "provide .help command" in ConsoleFixture() { (console, codeDir) => // part of Predefined.shared, which makes the below work in the repl without separate import - import io.shiftleft.codepropertygraph.Cpg.docSearchPackages + import io.shiftleft.semanticcpg.language.docSearchPackages import io.joern.console.testing.availableWidthProvider console.importCode(codeDir.toString) diff --git a/console/src/test/scala/io/joern/console/LanguageHelperTests.scala b/console/src/test/scala/io/joern/console/LanguageHelperTests.scala index 80b145767aea..afe860453215 100644 --- a/console/src/test/scala/io/joern/console/LanguageHelperTests.scala +++ b/console/src/test/scala/io/joern/console/LanguageHelperTests.scala @@ -1,7 +1,7 @@ package io.joern.console -import better.files.Dsl._ -import better.files._ +import better.files.Dsl.* +import better.files.* import io.shiftleft.codepropertygraph.generated.Languages import io.joern.console.cpgcreation.{guessLanguage, LlvmCpgGenerator} import org.scalatest.matchers.should.Matchers diff --git a/console/src/test/scala/io/joern/console/PluginManagerTests.scala b/console/src/test/scala/io/joern/console/PluginManagerTests.scala index 9a0957769c27..63eb5d2b933c 100644 --- a/console/src/test/scala/io/joern/console/PluginManagerTests.scala +++ b/console/src/test/scala/io/joern/console/PluginManagerTests.scala @@ -1,7 +1,7 @@ package io.joern.console -import better.files.Dsl._ -import better.files._ +import better.files.Dsl.* +import better.files.* import io.shiftleft.utils.ProjectRoot import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/console/src/test/scala/io/joern/console/testing/package.scala b/console/src/test/scala/io/joern/console/testing/package.scala index 897bb9e9633c..1b2053d82a7c 100644 --- a/console/src/test/scala/io/joern/console/testing/package.scala +++ b/console/src/test/scala/io/joern/console/testing/package.scala @@ -3,7 +3,7 @@ package io.joern.console import better.files.Dsl.* import better.files.* import io.joern.console.workspacehandling.Project -import overflowdb.traversal.help.Table.{AvailableWidthProvider, ConstantWidth} +import flatgraph.help.Table.{AvailableWidthProvider, ConstantWidth} import scala.util.Try diff --git a/console/src/test/scala/io/joern/console/workspacehandling/WorkspaceManagerTests.scala b/console/src/test/scala/io/joern/console/workspacehandling/WorkspaceManagerTests.scala index 4d1ddb669d49..867f2e4cfd06 100644 --- a/console/src/test/scala/io/joern/console/workspacehandling/WorkspaceManagerTests.scala +++ b/console/src/test/scala/io/joern/console/workspacehandling/WorkspaceManagerTests.scala @@ -1,6 +1,6 @@ package io.joern.console.workspacehandling -import better.files._ +import better.files.* import io.shiftleft.codepropertygraph.generated.Cpg import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/console/src/test/scala/io/joern/console/workspacehandling/WorkspaceTests.scala b/console/src/test/scala/io/joern/console/workspacehandling/WorkspaceTests.scala index 0abe2f0887d2..9d1b5f48b3ef 100644 --- a/console/src/test/scala/io/joern/console/workspacehandling/WorkspaceTests.scala +++ b/console/src/test/scala/io/joern/console/workspacehandling/WorkspaceTests.scala @@ -1,6 +1,6 @@ package io.joern.console.workspacehandling -import better.files.Dsl._ +import better.files.Dsl.* import better.files.File import io.joern.console.testing.availableWidthProvider import io.shiftleft.semanticcpg.testing.MockCpg diff --git a/console/src/test/scala/io/shiftleft/codepropertygraph/cpgloading/TestProtoCpg.scala b/console/src/test/scala/io/shiftleft/codepropertygraph/cpgloading/TestProtoCpg.scala deleted file mode 100644 index 7649f012eccd..000000000000 --- a/console/src/test/scala/io/shiftleft/codepropertygraph/cpgloading/TestProtoCpg.scala +++ /dev/null @@ -1,45 +0,0 @@ -package io.shiftleft.codepropertygraph.cpgloading - -import better.files.File -import io.shiftleft.proto.cpg.Cpg -import io.shiftleft.proto.cpg.Cpg.CpgStruct - -import java.io.FileOutputStream - -object TestProtoCpg { - - def createTestProtoCpg: File = { - val outDir = better.files.File.newTemporaryDirectory("cpgloadertests") - val outStream = new FileOutputStream((outDir / "1.proto").pathAsString) - CpgStruct - .newBuilder() - .addNode( - CpgStruct.Node - .newBuilder() - .setKey(1) - .setType(CpgStruct.Node.NodeType.valueOf("METHOD")) - .addProperty( - CpgStruct.Node.Property - .newBuilder() - .setName(Cpg.NodePropertyName.valueOf("FULL_NAME")) - .setValue( - Cpg.PropertyValue - .newBuilder() - .setStringValue("foo") - .build() - ) - .build - ) - .build() - ) - .build() - .writeTo(outStream) - outStream.close() - - val zipFile = better.files.File.newTemporaryFile("cpgloadertests", ".bin.zip") - outDir.zipTo(zipFile) - outDir.delete() - zipFile - } - -} diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/DefaultSemantics.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/DefaultSemantics.scala index fd8a1df0b264..49fad8d422cd 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/DefaultSemantics.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/DefaultSemantics.scala @@ -1,6 +1,6 @@ package io.joern.dataflowengineoss -import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, PassThroughMapping, Semantics} +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, PassThroughMapping, FullNameSemantics} import io.shiftleft.codepropertygraph.generated.Operators import scala.annotation.unused @@ -10,9 +10,9 @@ object DefaultSemantics { /** @return * a default set of common external procedure calls for all languages. */ - def apply(): Semantics = { + def apply(): FullNameSemantics = { val list = operatorFlows ++ cFlows ++ javaFlows - Semantics.fromList(list) + FullNameSemantics.fromList(list) } private def F = (x: String, y: List[(Int, Int)]) => FlowSemantic.from(x, y) @@ -43,6 +43,8 @@ object DefaultSemantics { F(Operators.notNullAssert, List((1, -1))), F(Operators.fieldAccess, List((1, -1))), F(Operators.getElementPtr, List((1, -1))), + PTF(Operators.modulo, List.empty), + PTF(Operators.arrayInitializer, List.empty), // TODO does this still exist? F(".incBy", List((1, 1), (2, 1), (3, 1), (4, 1))), @@ -60,18 +62,6 @@ object DefaultSemantics { F(Operators.preIncrement, List((1, 1), (1, -1))), F(Operators.sizeOf, List.empty[(Int, Int)]), - // some of those operators have duplicate mappings due to a typo - // - see https://github.com/ShiftLeftSecurity/codepropertygraph/pull/1630 - - F(".assignmentExponentiation", List((2, 1), (1, 1))), - F(".assignmentModulo", List((2, 1), (1, 1))), - F(".assignmentShiftLeft", List((2, 1), (1, 1))), - F(".assignmentLogicalShiftRight", List((2, 1), (1, 1))), - F(".assignmentArithmeticShiftRight", List((2, 1), (1, 1))), - F(".assignmentAnd", List((2, 1), (1, 1))), - F(".assignmentOr", List((2, 1), (1, 1))), - F(".assignmentXor", List((2, 1), (1, 1))), - // Language specific operators PTF(".tupleLiteral"), PTF(".dictLiteral"), @@ -157,6 +147,6 @@ object DefaultSemantics { * procedure semantics for operators and common external Java calls only. */ @unused - def javaSemantics(): Semantics = Semantics.fromList(operatorFlows ++ javaFlows) + def javaSemantics(): FullNameSemantics = FullNameSemantics.fromList(operatorFlows ++ javaFlows) } diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/dotgenerator/DdgGenerator.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/dotgenerator/DdgGenerator.scala index 61c55435add3..ca007968914d 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/dotgenerator/DdgGenerator.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/dotgenerator/DdgGenerator.scala @@ -1,15 +1,13 @@ package io.joern.dataflowengineoss.dotgenerator import io.joern.dataflowengineoss.DefaultSemantics -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Properties} -import io.joern.dataflowengineoss.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.EdgeTypes +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.semanticsloader.Semantics import io.shiftleft.semanticcpg.dotgenerator.DotSerializer.{Edge, Graph} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.utils.MemberAccess.isGenericMemberAccessName -import overflowdb.Node -import overflowdb.traversal.jIteratortoTraversal import scala.collection.mutable @@ -59,7 +57,7 @@ class DdgGenerator { } } - private def shouldBeDisplayed(v: Node): Boolean = !( + private def shouldBeDisplayed(v: StoredNode): Boolean = !( v.isInstanceOf[ControlStructure] || v.isInstanceOf[JumpTarget] ) @@ -91,7 +89,16 @@ class DdgGenerator { val allInEdges = v .inE(EdgeTypes.REACHING_DEF) .map(x => - Edge(x.outNode.asInstanceOf[StoredNode], v, srcVisible = true, x.property(Properties.Variable), edgeType) + // note: this looks strange, but let me explain... + // in overflowdb, edges were allowed multiple properties and this used to be `x.property(Properties.VARIABLE)` + // in flatgraph an edge may have zero or one properties and they're not named... + // in this case we know that we're dealing with ReachingDef edges which has the `variable` property + val variablePropertyMaybe = x.property match { + case null => null + case variableProperty: String => variableProperty + case _ => null + } + Edge(x.src.asInstanceOf[StoredNode], v, srcVisible = true, variablePropertyMaybe, edgeType) ) v match { diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/Path.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/Path.scala index 241a3d95750d..af5d96936b3d 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/Path.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/Path.scala @@ -3,8 +3,8 @@ package io.joern.dataflowengineoss.language import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, CfgNode, Member, MethodParameterIn} import io.shiftleft.semanticcpg import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help.Table -import overflowdb.traversal.help.Table.AvailableWidthProvider +import flatgraph.help.Table +import flatgraph.help.Table.AvailableWidthProvider case class Path(elements: List[AstNode]) { def resultPairs(): List[(String, Option[Int])] = { @@ -26,7 +26,7 @@ object Path { availableWidthProvider: AvailableWidthProvider = semanticcpg.defaultAvailableWidthProvider ): Show[Path] = { path => val table = Table( - columnNames = Array("nodeType", "tracked", "line", "method", "file"), + columnNames = Seq("nodeType", "tracked", "line", "method", "file"), rows = path.elements.map { astNode => val nodeType = astNode.getClass.getSimpleName val lineNumber = astNode.lineNumber.getOrElse("N/A").toString @@ -36,7 +36,7 @@ object Path { case member: Member => val tracked = member.name val methodName = "" - Array(nodeType, tracked, lineNumber, methodName, fileName) + Seq(nodeType, tracked, lineNumber, methodName, fileName) case cfgNode: CfgNode => val method = cfgNode.method val methodName = method.name @@ -46,7 +46,7 @@ object Path { s"$methodName($paramsPretty)" case _ => cfgNode.statement.repr } - Array(nodeType, statement, lineNumber, methodName, fileName) + Seq(nodeType, statement, lineNumber, methodName, fileName) } } ) diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/nodemethods/ExpressionMethods.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/nodemethods/ExpressionMethods.scala index 52a66ffecbfa..362cefee6d46 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/nodemethods/ExpressionMethods.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/nodemethods/ExpressionMethods.scala @@ -64,8 +64,8 @@ class ExpressionMethods[NodeType <: Expression](val node: NodeType) extends AnyV srcIndex == node.argumentIndex && dstName == tgt.argumentName.get case FlowMapping(ParameterNode(srcIndex, _), ParameterNode(dstIndex, _)) => srcIndex == node.argumentIndex && dstIndex == tgt.argumentIndex - case PassThroughMapping if tgt.argumentIndex == node.argumentIndex || tgt.argumentIndex == -1 => true - case _ => false + case PassThroughMapping => node.argumentIndex == tgt.argumentIndex && node.argumentName == tgt.argumentName + case _ => false } } } @@ -73,9 +73,7 @@ class ExpressionMethods[NodeType <: Expression](val node: NodeType) extends AnyV /** Retrieve flow semantic for the call this argument is a part of. */ def semanticsForCallByArg(implicit semantics: Semantics): Iterator[FlowSemantic] = { - argToMethods(node).flatMap { method => - semantics.forMethod(method.fullName) - } + argToMethods(node).flatMap(semantics.forMethod) } private def argToMethods(arg: Expression): Iterator[Method] = { diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala index fef9acf3c1bd..7a914187e6b6 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala @@ -10,7 +10,7 @@ import io.shiftleft.semanticcpg.language.* import scala.collection.mutable import scala.jdk.CollectionConverters.* -class ExtendedCfgNodeMethods[NodeType <: CfgNode](val node: NodeType) extends AnyVal { +class ExtendedCfgNodeMethods[CfgNodeType <: CfgNode](val node: CfgNodeType) extends AnyVal { /** Convert to nearest AST node */ diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/package.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/package.scala index 842ce65d0d74..7101c872177f 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/package.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/package.scala @@ -4,7 +4,7 @@ import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* import io.joern.dataflowengineoss.language.dotextension.DdgNodeDot import io.joern.dataflowengineoss.language.nodemethods.{ExpressionMethods, ExtendedCfgNodeMethods} -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc import scala.language.implicitConversions diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpCpg14.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpCpg14.scala index 16ea1b420cb0..4575fba5ab9b 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpCpg14.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpCpg14.scala @@ -2,9 +2,9 @@ package io.joern.dataflowengineoss.layers.dataflows import better.files.File import io.joern.dataflowengineoss.DefaultSemantics -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} case class Cpg14DumpOptions(var outDir: String) extends LayerCreatorOptions {} diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpDdg.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpDdg.scala index ec5db89d16c6..c76e5b9275d5 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpDdg.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpDdg.scala @@ -2,9 +2,9 @@ package io.joern.dataflowengineoss.layers.dataflows import better.files.File import io.joern.dataflowengineoss.DefaultSemantics -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} case class DdgDumpOptions(var outDir: String) extends LayerCreatorOptions {} diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpPdg.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpPdg.scala index 2221a5372c8d..3b3dc6d24d30 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpPdg.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/DumpPdg.scala @@ -2,9 +2,9 @@ package io.joern.dataflowengineoss.layers.dataflows import better.files.File import io.joern.dataflowengineoss.DefaultSemantics -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} case class PdgDumpOptions(var outDir: String) extends LayerCreatorOptions {} diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/OssDataFlow.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/OssDataFlow.scala index ce0fa3800eef..60dbeef7a526 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/OssDataFlow.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/layers/dataflows/OssDataFlow.scala @@ -2,7 +2,7 @@ package io.joern.dataflowengineoss.layers.dataflows import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.passes.reachingdef.ReachingDefPass -import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, FullNameSemantics, Semantics} import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} object OssDataFlow { @@ -12,23 +12,15 @@ object OssDataFlow { def defaultOpts = new OssDataFlowOptions() } -class OssDataFlowOptions( - var maxNumberOfDefinitions: Int = 4000, - var extraFlows: List[FlowSemantic] = List.empty[FlowSemantic] -) extends LayerCreatorOptions {} +class OssDataFlowOptions(var maxNumberOfDefinitions: Int = 4000, var semantics: Semantics = DefaultSemantics()) + extends LayerCreatorOptions {} -class OssDataFlow(opts: OssDataFlowOptions)(implicit - s: Semantics = Semantics.fromList(DefaultSemantics().elements ++ opts.extraFlows) -) extends LayerCreator { +class OssDataFlow(opts: OssDataFlowOptions)(implicit val semantics: Semantics = opts.semantics) extends LayerCreator { override val overlayName: String = OssDataFlow.overlayName override val description: String = OssDataFlow.description override def create(context: LayerCreatorContext): Unit = { - val cpg = context.cpg - val enhancementExecList = Iterator(new ReachingDefPass(cpg, opts.maxNumberOfDefinitions)) - enhancementExecList.zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, index) - } + ReachingDefPass(context.cpg, opts.maxNumberOfDefinitions).createAndApply() } } diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/DdgGenerator.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/DdgGenerator.scala index 887531648c7e..7693f28757ff 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/DdgGenerator.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/DdgGenerator.scala @@ -4,10 +4,10 @@ import io.joern.dataflowengineoss.{globalFromLiteral, identifierToFirstUsages} import io.joern.dataflowengineoss.queryengine.AccessPathUsage.toTrackedBaseAndAccessPathSimple import io.joern.dataflowengineoss.semanticsloader.Semantics import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Operators, PropertyNames} +import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Operators} import io.shiftleft.semanticcpg.accesspath.MatchResult import io.shiftleft.semanticcpg.language.* -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import scala.collection.{Set, mutable} @@ -132,7 +132,7 @@ class DdgGenerator(semantics: Semantics) { // There is always an edge from the method input parameter // to the corresponding method output parameter as modifications // of the input parameter only affect a copy. - paramOut.paramIn.foreach { paramIn => + paramOut.start.paramIn.foreach { paramIn => addEdge(paramIn, paramOut, paramIn.name) } usageAnalyzer.usedIncomingDefs(paramOut).foreach { case (_, inElements) => @@ -224,7 +224,7 @@ class DdgGenerator(semantics: Semantics) { (fromNode, toNode) match { case (parentNode: CfgNode, childNode: CfgNode) if EdgeValidator.isValidEdge(childNode, parentNode) => - dstGraph.addEdge(fromNode, toNode, EdgeTypes.REACHING_DEF, PropertyNames.VARIABLE, variable) + dstGraph.addEdge(fromNode, toNode, EdgeTypes.REACHING_DEF, variable) case _ => } diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/EdgeValidator.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/EdgeValidator.scala index 9f8666fd6848..505bd4e3b7fc 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/EdgeValidator.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/EdgeValidator.scala @@ -1,7 +1,7 @@ package io.joern.dataflowengineoss.passes.reachingdef import io.joern.dataflowengineoss.language.* -import io.joern.dataflowengineoss.queryengine.Engine.isOutputArgOfInternalMethod +import io.joern.dataflowengineoss.queryengine.Engine.{isOutputArgOfInternalMethod, semanticsForCall} import io.joern.dataflowengineoss.semanticsloader.{ FlowMapping, FlowPath, @@ -50,7 +50,7 @@ object EdgeValidator { */ private def isCallRetval(parentNode: StoredNode)(implicit semantics: Semantics): Boolean = parentNode match { - case call: Call => semantics.forMethod(call.methodFullName).exists(!explicitlyFlowsToReturnValue(_)) + case call: Call => semanticsForCall(call).exists(!explicitlyFlowsToReturnValue(_)) case _ => false } @@ -58,8 +58,10 @@ object EdgeValidator { flowSemantic.mappings.exists(explicitlyFlowsToReturnValue) private def explicitlyFlowsToReturnValue(flowPath: FlowPath): Boolean = flowPath match { - case FlowMapping(_, ParameterNode(dst, _)) => dst == -1 - case PassThroughMapping => true - case _ => false + // Some frontends (e.g. python) denote named arguments using `-1` as the argument index. As such + // `-1` denotes the return value only if there's no argument name. + case FlowMapping(_, ParameterNode(-1, None)) => true + case PassThroughMapping => true + case _ => false } } diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala index a5aab780ed00..a01796831c3c 100755 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala @@ -2,9 +2,9 @@ package io.joern.dataflowengineoss.passes.reachingdef import io.joern.dataflowengineoss.semanticsloader.Semantics import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.slf4j.{Logger, LoggerFactory} import scala.collection.mutable @@ -16,7 +16,7 @@ class ReachingDefPass(cpg: Cpg, maxNumberOfDefinitions: Int = 4000)(implicit s: private val logger: Logger = LoggerFactory.getLogger(this.getClass) // If there are any regex method full names, load them early - s.loadRegexSemantics(cpg) + s.initialize(cpg) override def generateParts(): Array[Method] = cpg.method.toArray diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala index 936f43ef4b9f..e06bb34447e0 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala @@ -167,7 +167,7 @@ class ReachingDefTransferFunction(flowGraph: ReachingDefFlowGraph) val gen: Map[StoredNode, mutable.BitSet] = initGen(method).withDefaultValue(mutable.BitSet()) - val kill: Map[StoredNode, Set[Definition]] = + val kill: Map[StoredNode, mutable.BitSet] = initKill(method, gen).withDefaultValue(mutable.BitSet()) /** For a given flow graph node `n` and set of definitions, apply the transfer function to obtain the updated set of @@ -224,7 +224,7 @@ class ReachingDefTransferFunction(flowGraph: ReachingDefFlowGraph) * All operations in our graph are represented by calls and non-operations such as identifiers or field-identifiers * have empty gen and kill sets, meaning that they just pass on definitions unaltered. */ - private def initKill(method: Method, gen: Map[StoredNode, Set[Definition]]): Map[StoredNode, Set[Definition]] = { + private def initKill(method: Method, gen: Map[StoredNode, mutable.BitSet]): Map[StoredNode, mutable.BitSet] = { val allIdentifiers: Map[String, List[CfgNode]] = { val results = mutable.Map.empty[String, List[CfgNode]] @@ -259,42 +259,44 @@ class ReachingDefTransferFunction(flowGraph: ReachingDefFlowGraph) * calculate kill(call) based on gen(call). */ private def killsForGens( - genOfCall: Set[Definition], + genOfCall: mutable.BitSet, allIdentifiers: Map[String, List[CfgNode]], allCalls: Map[String, List[Call]] - ): Set[Definition] = { + ): mutable.BitSet = { - def definitionsOfSameVariable(definition: Definition): Set[Definition] = { + def definitionsOfSameVariable(definition: Definition): Iterator[Definition] = { val definedNodes = flowGraph.numberToNode(definition) match { case param: MethodParameterIn => - allIdentifiers(param.name) + allIdentifiers(param.name).iterator .filter(x => x.id != param.id) case identifier: Identifier => - val sameIdentifiers = allIdentifiers(identifier.name) + val sameIdentifiers = allIdentifiers(identifier.name).iterator .filter(x => x.id != identifier.id) /** Killing an identifier should also kill field accesses on that identifier. For example, a reassignment `x = * new Box()` should kill any previous calls to `x.value`, `x.length()`, etc. */ - val sameObjects: Iterable[Call] = allCalls.values.flatten + val sameObjects: Iterator[Call] = allCalls.valuesIterator.flatten .filter(_.name == Operators.fieldAccess) .filter(_.ast.isIdentifier.nameExact(identifier.name).nonEmpty) sameIdentifiers ++ sameObjects case call: Call => - allCalls(call.code) + allCalls(call.code).iterator .filter(x => x.id != call.id) - case _ => Set() + case _ => Iterator.empty } definedNodes // It can happen that the CFG is broken and contains isolated nodes, // in which case they are not in `nodeToNumber`. Let's filter those. - .collect { case x if nodeToNumber.contains(x) => Definition.fromNode(x, nodeToNumber) }.toSet + .collect { case x if nodeToNumber.contains(x) => Definition.fromNode(x, nodeToNumber) } } - genOfCall.flatMap { definition => - definitionsOfSameVariable(definition) + val res = mutable.BitSet() + for (definition <- genOfCall) { + res.addAll(definitionsOfSameVariable(definition)) } + res } } diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/AccessPathUsage.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/AccessPathUsage.scala index 341f29621d53..534b8bb565b8 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/AccessPathUsage.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/AccessPathUsage.scala @@ -1,8 +1,8 @@ package io.joern.dataflowengineoss.queryengine -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.accesspath._ -import io.shiftleft.semanticcpg.language.{AccessPathHandling, toCallMethods} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.accesspath.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.utils.MemberAccess import org.slf4j.LoggerFactory diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/Engine.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/Engine.scala index 1c0c279a9758..0e963c9c8aaf 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/Engine.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/Engine.scala @@ -1,14 +1,14 @@ package io.joern.dataflowengineoss.queryengine +import flatgraph.Edge import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.passes.reachingdef.EdgeValidator import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Properties} +import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.semanticcpg.language.* import org.slf4j.{Logger, LoggerFactory} -import overflowdb.Edge import java.util.concurrent.* import scala.collection.mutable @@ -205,9 +205,11 @@ object Engine { private def elemForEdge(e: Edge, callSiteStack: List[Call] = List())(implicit semantics: Semantics ): Option[PathElement] = { - val curNode = e.inNode().asInstanceOf[CfgNode] - val parNode = e.outNode().asInstanceOf[CfgNode] - val outLabel = Some(e.property(Properties.Variable)).getOrElse("") + val curNode = e.dst.asInstanceOf[CfgNode] + val parNode = e.src.asInstanceOf[CfgNode] + // note: flatgraph only allows at most one property per edge, and since we know :tm: that this is a ReachingDef edge it must be the Variable property... + val variablePropertyMaybe = Option(e.property).map(_.asInstanceOf[String]) + val outLabel = variablePropertyMaybe.getOrElse("") if (!EdgeValidator.isValidEdge(curNode, parNode)) { return None @@ -254,9 +256,8 @@ object Engine { private def ddgInE(node: CfgNode, path: Vector[PathElement], callSiteStack: List[Call] = List()): Vector[Edge] = { node .inE(EdgeTypes.REACHING_DEF) - .asScala .filter { e => - e.outNode() match { + e.src match { case srcNode: CfgNode => !srcNode.isInstanceOf[Method] && !path .map(x => x.node) @@ -291,9 +292,7 @@ object Engine { } def semanticsForCall(call: Call)(implicit semantics: Semantics): List[FlowSemantic] = { - Engine.methodsForCall(call).flatMap { method => - semantics.forMethod(method.fullName) - } + Engine.methodsForCall(call).flatMap(semantics.forMethod) } } diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/HeldTaskCompletion.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/HeldTaskCompletion.scala index f0b0cf5fc9e1..46326fbc0e6c 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/HeldTaskCompletion.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/HeldTaskCompletion.scala @@ -1,7 +1,7 @@ package io.joern.dataflowengineoss.queryengine import scala.collection.mutable -import scala.collection.parallel.CollectionConverters._ +import scala.collection.parallel.CollectionConverters.* /** Complete held tasks using the result table. The result table is modified in the process. * diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/SourcesToStartingPoints.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/SourcesToStartingPoints.scala index cac222fbd476..74b651ef428a 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/SourcesToStartingPoints.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/SourcesToStartingPoints.scala @@ -32,7 +32,7 @@ object SourcesToStartingPoints { .map(src => { // We need to get Cpg wrapper from graph. Hence we are taking head element from source iterator. // This will also ensure if the source list is empty then these tasks are invoked. - val cpg = Cpg(src.graph()) + val cpg = Cpg(src.graph) val (startingPoints, methodTasks) = calculateStartingPoints(sources, executorService) val startingPointFromUsageInOtherClasses = calculateStatingPointsWithUsageInOtherClasses(methodTasks, cpg, executorService) @@ -147,7 +147,10 @@ class SourceToStartingPointsInMethod( private def usageInOtherClasses(m: Method, usageInputs: List[UsageInput]): List[StartingPointWithSource] = { usageInputs.flatMap { case UsageInput(src, typeDecl, astNode) => m.fieldAccess - .where(_.argument(1).isIdentifier.typeFullNameExact(typeDecl.fullName)) + .or( + _.argument(1).isIdentifier.typeFullNameExact(typeDecl.fullName), + _.argument(1).isTypeRef.typeFullNameExact(typeDecl.fullName) + ) .where { x => astNode match { case identifier: Identifier => @@ -189,7 +192,8 @@ abstract class BaseSourceToStartingPoints extends Callable[Unit] { protected def sourceToStartingPoints(src: StoredNode): (List[CfgNode], List[UsageInput]) = { src match { case methodReturn: MethodReturn => - (methodReturn.method.callIn.l, Nil) + // n.b. there's a generated `callIn` step that we really want to use, but it's shadowed by `MethodTraversal.callIn` + (methodReturn.method._callIn.cast[Call].l, Nil) case lit: Literal => val usageInput = targetsToClassIdentifierPair(literalToInitializedMembers(lit), src) val uses = usages(usageInput) diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskCreator.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskCreator.scala index f10f105d2409..ebf5909b894d 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskCreator.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskCreator.scala @@ -1,17 +1,8 @@ package io.joern.dataflowengineoss.queryengine import io.joern.dataflowengineoss.queryengine.Engine.argToOutputParams -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.{ - Call, - Expression, - Method, - MethodParameterIn, - MethodParameterOut, - MethodRef, - Return -} -import io.shiftleft.semanticcpg.language.NoResolve +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{Cpg, Languages} import io.shiftleft.semanticcpg.language.* import org.slf4j.{Logger, LoggerFactory} @@ -98,9 +89,15 @@ class TaskCreator(context: EngineContext) { * `m`, return `foo` in `foo.bar(m)` TODO: I'm not sure whether `methodRef.methodFullNameExact(...)` uses an index. * If not, then caching these lookups or keeping a map of all method names to their references may make sense. */ - - private def paramToMethodRefCallReceivers(param: MethodParameterIn): List[Expression] = - new Cpg(param.graph()).methodRef.methodFullNameExact(param.method.fullName).inCall.argument(0).l + private def paramToMethodRefCallReceivers(param: MethodParameterIn): List[Expression] = { + val cpg = new Cpg(param.graph) + def trav = cpg.methodRef.methodFullNameExact(param.method.fullName).inCall + cpg.metaData.language.headOption match { + // Kotlin higher-level functions are often static and don't have the arg0 recv + case Some(Languages.KOTLIN) => trav.argument(1).l + case _ => trav.argument(0).l + } + } /** Create new tasks from all results that end in an output argument, including return arguments. In this case, we * want to traverse to corresponding method output parameters and method return nodes respectively. diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala index 263251775694..784ab9ee45a5 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala @@ -2,8 +2,8 @@ package io.joern.dataflowengineoss.queryengine import io.joern.dataflowengineoss.queryengine.QueryEngineStatistics.{PATH_CACHE_HITS, PATH_CACHE_MISSES} import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language.{toCfgNodeMethods, toExpressionMethods, _} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import java.util.concurrent.Callable import scala.collection.mutable diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/FullNameSemantics.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/FullNameSemantics.scala new file mode 100644 index 000000000000..3a4196a9c09c --- /dev/null +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/FullNameSemantics.scala @@ -0,0 +1,79 @@ +package io.joern.dataflowengineoss.semanticsloader + +import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.codepropertygraph.generated.nodes.Method +import io.shiftleft.semanticcpg.language.* +import org.slf4j.LoggerFactory + +import scala.collection.mutable + +object FullNameSemantics { + + private val logger = LoggerFactory.getLogger(getClass) + + /** Builds FullNameSemantics given their constituent FlowSemantics. Same methodFullNamed FlowSemantic elements are + * combined into a single one with both of their FlowMappings. + */ + def fromList(elements: List[FlowSemantic]): FullNameSemantics = FullNameSemantics( + elements.groupBy(_.methodFullName).map { (fullName, semantics) => + val howMany = semantics.length + if (howMany > 1) { + logger.warn(s"$howMany competing FlowSemantics found for $fullName, merging them") + } + fullName -> FlowSemantic( + methodFullName = fullName, + mappings = semantics.flatMap(_.mappings), + regex = semantics.exists(_.regex) + ) + } + ) + + def empty: FullNameSemantics = fromList(List()) + +} + +class FullNameSemantics private (methodToSemantic: Map[String, FlowSemantic]) extends Semantics { + + /** The map below keeps a mapping between results of a regex and the regex string it matches. e.g. + * + * `path/to/file.py:.Foo.sink` -> `^path.*Foo\\.sink$` + */ + private val regexMatchedFullNames = mutable.HashMap.empty[String, String] + + /** Initialize all the method semantics that use regex with all their regex results before query time. + */ + override def initialize(cpg: Cpg): Unit = { + import io.shiftleft.semanticcpg.language._ + + methodToSemantic.filter(_._2.regex).foreach { case (regexString, _) => + cpg.method.fullName(regexString).fullName.foreach { methodMatch => + regexMatchedFullNames.put(methodMatch, regexString) + } + } + } + + def elements: List[FlowSemantic] = methodToSemantic.values.toList + + private def forMethod(fullName: String): Option[FlowSemantic] = regexMatchedFullNames.get(fullName) match { + case Some(matchedFullName) => methodToSemantic.get(matchedFullName) + case None => methodToSemantic.get(fullName) + } + + override def forMethod(method: Method): Option[FlowSemantic] = forMethod(method.fullName) + + def serialize: String = { + elements + .sortBy(_.methodFullName) + .map { elem => + s"\"${elem.methodFullName}\" " + elem.mappings + .collect { case FlowMapping(x, y) => s"$x -> $y" } + .mkString(" ") + } + .mkString("\n") + } + + /** Immutably extends the current `FullNameSemantics` with `extraFlows`. + */ + def plus(extraFlows: List[FlowSemantic]): FullNameSemantics = FullNameSemantics.fromList(elements ++ extraFlows) + +} diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/FullNameSemanticsParser.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/FullNameSemanticsParser.scala new file mode 100644 index 000000000000..669495da8a39 --- /dev/null +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/FullNameSemanticsParser.scala @@ -0,0 +1,73 @@ +package io.joern.dataflowengineoss.semanticsloader + +import io.joern.dataflowengineoss.SemanticsParser.MappingContext +import io.joern.dataflowengineoss.{SemanticsBaseListener, SemanticsLexer, SemanticsParser} +import org.antlr.v4.runtime.tree.ParseTreeWalker +import org.antlr.v4.runtime.{CharStream, CharStreams, CommonTokenStream} + +import scala.collection.mutable +import scala.jdk.CollectionConverters.* + +class FullNameSemanticsParser { + + def parse(input: String): List[FlowSemantic] = { + val charStream = CharStreams.fromString(input) + parseCharStream(charStream) + } + + def parseFile(fileName: String): List[FlowSemantic] = { + val charStream = CharStreams.fromFileName(fileName) + parseCharStream(charStream) + } + + private def parseCharStream(charStream: CharStream): List[FlowSemantic] = { + val lexer = new SemanticsLexer(charStream) + val tokenStream = new CommonTokenStream(lexer) + val parser = new SemanticsParser(tokenStream) + val treeWalker = new ParseTreeWalker() + + val tree = parser.taintSemantics() + val listener = new Listener() + treeWalker.walk(listener, tree) + listener.result.toList + } + + implicit class AntlrFlowExtensions(val ctx: MappingContext) { + + def isPassThrough: Boolean = Option(ctx.PASSTHROUGH()).isDefined + + def srcIdx: Int = ctx.src().argIdx().NUMBER().getText.toInt + + def srcArgName: Option[String] = Option(ctx.src().argName()).map(_.name().getText) + + def dstIdx: Int = ctx.dst().argIdx().NUMBER().getText.toInt + + def dstArgName: Option[String] = Option(ctx.dst().argName()).map(_.name().getText) + + } + + private class Listener extends SemanticsBaseListener { + + val result: mutable.ListBuffer[FlowSemantic] = mutable.ListBuffer[FlowSemantic]() + + override def enterTaintSemantics(ctx: SemanticsParser.TaintSemanticsContext): Unit = { + ctx.singleSemantic().asScala.foreach { semantic => + val methodName = semantic.methodName().name().getText + val mappings = semantic.mapping().asScala.toList.map(ctxToParamMapping) + result.addOne(FlowSemantic(methodName, mappings)) + } + } + + private def ctxToParamMapping(ctx: MappingContext): FlowPath = + if (ctx.isPassThrough) { + PassThroughMapping + } else { + val src = ParameterNode(ctx.srcIdx, ctx.srcArgName) + val dst = ParameterNode(ctx.dstIdx, ctx.dstArgName) + + FlowMapping(src, dst) + } + + } + +} diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/Parser.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/Parser.scala deleted file mode 100644 index f3e245d575c3..000000000000 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/Parser.scala +++ /dev/null @@ -1,211 +0,0 @@ -package io.joern.dataflowengineoss.semanticsloader - -import io.joern.dataflowengineoss.SemanticsParser.MappingContext -import io.joern.dataflowengineoss.{SemanticsBaseListener, SemanticsLexer, SemanticsParser} -import io.shiftleft.codepropertygraph.generated.Cpg -import org.antlr.v4.runtime.tree.ParseTreeWalker -import org.antlr.v4.runtime.{CharStream, CharStreams, CommonTokenStream} - -import scala.collection.mutable -import scala.jdk.CollectionConverters._ - -object Semantics { - - def fromList(elements: List[FlowSemantic]): Semantics = { - new Semantics( - mutable.Map.newBuilder - .addAll(elements.map { e => - e.methodFullName -> e - }) - .result() - ) - } - - def empty: Semantics = fromList(List()) - -} - -class Semantics private (methodToSemantic: mutable.Map[String, FlowSemantic]) { - - /** The map below keeps a mapping between results of a regex and the regex string it matches. e.g. - * - * `path/to/file.py:.Foo.sink` -> `^path.*Foo\\.sink$` - */ - private val regexMatchedFullNames = mutable.HashMap.empty[String, String] - - /** Initialize all the method semantics that use regex with all their regex results before query time. - */ - def loadRegexSemantics(cpg: Cpg): Unit = { - import io.shiftleft.semanticcpg.language._ - - methodToSemantic.filter(_._2.regex).foreach { case (regexString, _) => - cpg.method.fullName(regexString).fullName.foreach { methodMatch => - regexMatchedFullNames.put(methodMatch, regexString) - } - } - } - - def elements: List[FlowSemantic] = methodToSemantic.values.toList - - def forMethod(fullName: String): Option[FlowSemantic] = regexMatchedFullNames.get(fullName) match { - case Some(matchedFullName) => methodToSemantic.get(matchedFullName) - case None => methodToSemantic.get(fullName) - } - - def serialize: String = { - elements - .sortBy(_.methodFullName) - .map { elem => - s"\"${elem.methodFullName}\" " + elem.mappings - .collect { case FlowMapping(x, y) => s"$x -> $y" } - .mkString(" ") - } - .mkString("\n") - } - -} -case class FlowSemantic(methodFullName: String, mappings: List[FlowPath] = List.empty, regex: Boolean = false) - -object FlowSemantic { - - def from(methodFullName: String, mappings: List[?], regex: Boolean = false): FlowSemantic = { - FlowSemantic( - methodFullName, - mappings.map { - case (src: Int, dst: Int) => FlowMapping(src, dst) - case (srcIdx: Int, src: String, dst: Int) => FlowMapping(srcIdx, src, dst) - case (src: Int, dstIdx: Int, dst: String) => FlowMapping(src, dstIdx, dst) - case (srcIdx: Int, src: String, dstIdx: Int, dst: String) => FlowMapping(srcIdx, src, dstIdx, dst) - case x: FlowMapping => x - }, - regex - ) - } - -} - -abstract class FlowNode - -/** Collects parameters and return nodes under a common trait. This trait acknowledges their argument index which is - * relevant when a caller wants to coordinate relevant tainted flows through specific arguments and the return flow. - */ -trait ParamOrRetNode extends FlowNode { - - /** Temporary backward compatible idx field. - * - * @return - * the argument index. - */ - def index: Int -} - -/** A parameter where the index of the argument matches the position of the parameter at the callee. The name is used to - * match named arguments if used instead of positional arguments. - * - * @param index - * the position or argument index. - * @param name - * the name of the parameter. - */ -case class ParameterNode(index: Int, name: Option[String] = None) extends ParamOrRetNode - -object ParameterNode { - def apply(index: Int, name: String): ParameterNode = ParameterNode(index, Option(name)) -} - -/** Represents explicit mappings or special cases. - */ -sealed trait FlowPath - -/** Maps flow between arguments based on how they interact as parameters at the callee. - * - * @param src - * source of the flow. - * @param dst - * destination of the flow. - */ -case class FlowMapping(src: FlowNode, dst: FlowNode) extends FlowPath - -object FlowMapping { - def apply(from: Int, to: Int): FlowMapping = FlowMapping(ParameterNode(from), ParameterNode(to)) - - def apply(fromIdx: Int, from: String, toIdx: Int, to: String): FlowMapping = - FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx, to)) - - def apply(fromIdx: Int, from: String, toIdx: Int): FlowMapping = - FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx)) - - def apply(from: Int, toIdx: Int, to: String): FlowMapping = FlowMapping(ParameterNode(from), ParameterNode(toIdx, to)) - -} - -/** Represents an instance where parameters are not sanitized, may affect the return value, and do not cross-taint. e.g. - * foo(1, 2) = 1 -> 1, 2 -> 2, 1 -> -1, 2 -> -1 - * - * The main benefit is that this works for unbounded parameters e.g. VARARGS. Note this does not taint 0 -> 0. - */ -object PassThroughMapping extends FlowPath - -class Parser() { - - def parse(input: String): List[FlowSemantic] = { - val charStream = CharStreams.fromString(input) - parseCharStream(charStream) - } - - def parseFile(fileName: String): List[FlowSemantic] = { - val charStream = CharStreams.fromFileName(fileName) - parseCharStream(charStream) - } - - private def parseCharStream(charStream: CharStream): List[FlowSemantic] = { - val lexer = new SemanticsLexer(charStream) - val tokenStream = new CommonTokenStream(lexer) - val parser = new SemanticsParser(tokenStream) - val treeWalker = new ParseTreeWalker() - - val tree = parser.taintSemantics() - val listener = new Listener() - treeWalker.walk(listener, tree) - listener.result.toList - } - - implicit class AntlrFlowExtensions(val ctx: MappingContext) { - - def isPassThrough: Boolean = Option(ctx.PASSTHROUGH()).isDefined - - def srcIdx: Int = ctx.src().argIdx().NUMBER().getText.toInt - - def srcArgName: Option[String] = Option(ctx.src().argName()).map(_.name().getText) - - def dstIdx: Int = ctx.dst().argIdx().NUMBER().getText.toInt - - def dstArgName: Option[String] = Option(ctx.dst().argName()).map(_.name().getText) - - } - - private class Listener extends SemanticsBaseListener { - - val result: mutable.ListBuffer[FlowSemantic] = mutable.ListBuffer[FlowSemantic]() - - override def enterTaintSemantics(ctx: SemanticsParser.TaintSemanticsContext): Unit = { - ctx.singleSemantic().asScala.foreach { semantic => - val methodName = semantic.methodName().name().getText - val mappings = semantic.mapping().asScala.toList.map(ctxToParamMapping) - result.addOne(FlowSemantic(methodName, mappings)) - } - } - - private def ctxToParamMapping(ctx: MappingContext): FlowPath = - if (ctx.isPassThrough) { - PassThroughMapping - } else { - val src = ParameterNode(ctx.srcIdx, ctx.srcArgName) - val dst = ParameterNode(ctx.dstIdx, ctx.dstArgName) - - FlowMapping(src, dst) - } - - } - -} diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/Semantics.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/Semantics.scala new file mode 100644 index 000000000000..1a67ec60c42a --- /dev/null +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/semanticsloader/Semantics.scala @@ -0,0 +1,167 @@ +package io.joern.dataflowengineoss.semanticsloader + +import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.codepropertygraph.generated.nodes.Method +import io.shiftleft.semanticcpg.language.* + +trait Semantics { + + /** Useful for `Semantics` that benefit from having some kind of internal state tailored to the current CPG. + */ + def initialize(cpg: Cpg): Unit = {} + + def forMethod(method: Method): Option[FlowSemantic] + + /** Builds a new `Semantics` whose `forMethod` behaviour first lookups in `other` and only if it fails (i.e. returns + * `None`) lookups in the current one. + */ + def after(other: Semantics): Semantics = Semantics.compose(this, other) +} + +object Semantics { + + private def compose(first: Semantics, second: Semantics): Semantics = new Semantics { + + override def initialize(cpg: Cpg): Unit = { + second.initialize(cpg) + first.initialize(cpg) + } + + override def forMethod(method: Method): Option[FlowSemantic] = + second.forMethod(method).orElse { first.forMethod(method) } + } +} + +/** The empty Semantics, whose `forMethod` always fails, i.e. the identity under `Semantics.after`. */ +object NoSemantics extends Semantics { + + override def forMethod(method: Method): Option[FlowSemantic] = None +} + +/** The nil Semantics, whose `forMethod` always succeeds but returns the empty (nil) mapping. */ +object NilSemantics { + + /** Builds a universal nil semantics. Beware this is right-absorbing under `Semantics.after`. */ + def apply(): Semantics = new Semantics { + override def forMethod(method: Method): Option[FlowSemantic] = Some(FlowSemantic(method.fullName, List.empty)) + } + + /** Extensionally builds a nil semantics. */ + def where(methodFullNames: List[String], regex: Boolean = false): Semantics = + FullNameSemantics.fromList(methodFullNames.map { + FlowSemantic(_, List.empty, regex) + }) + + /** Intensionally builds a nil semantics. */ + def where(predicate: Method => Boolean): Semantics = new Semantics { + override def forMethod(method: Method): Option[FlowSemantic] = Option.when(predicate(method)) { + FlowSemantic(method.fullName, List.empty) + } + } +} + +/** Semantics whose mappings are: 0->0, PassThroughMapping. */ +object NoCrossTaintSemantics { + + /** Builds a universal no-cross-taint semantics. Beware this is right-absorbing under `Semantics.after`. */ + def apply(): Semantics = new Semantics { + override def forMethod(method: Method): Option[FlowSemantic] = Some( + FlowSemantic(method.fullName, List(FlowMapping(0, 0), PassThroughMapping)) + ) + } + + /** Extensionally builds a no-cross-taint semantics. */ + def where(methodFullNames: List[String], regex: Boolean = false): Semantics = + FullNameSemantics.fromList(methodFullNames.map { + FlowSemantic(_, List(FlowMapping(0, 0), PassThroughMapping), regex) + }) + + /** Intensionally builds a no-cross-taint semantics. */ + def where(predicate: Method => Boolean): Semantics = new Semantics { + override def forMethod(method: Method): Option[FlowSemantic] = Option.when(predicate(method)) { + FlowSemantic(method.fullName, List(FlowMapping(0, 0), PassThroughMapping)) + } + } +} + +case class FlowSemantic(methodFullName: String, mappings: List[FlowPath] = List.empty, regex: Boolean = false) + +object FlowSemantic { + + def from(methodFullName: String, mappings: List[?], regex: Boolean = false): FlowSemantic = { + FlowSemantic( + methodFullName, + mappings.map { + case (src: Int, dst: Int) => FlowMapping(src, dst) + case (srcIdx: Int, src: String, dst: Int) => FlowMapping(srcIdx, src, dst) + case (src: Int, dstIdx: Int, dst: String) => FlowMapping(src, dstIdx, dst) + case (srcIdx: Int, src: String, dstIdx: Int, dst: String) => FlowMapping(srcIdx, src, dstIdx, dst) + case x: FlowMapping => x + }, + regex + ) + } + +} + +abstract class FlowNode + +/** Collects parameters and return nodes under a common trait. This trait acknowledges their argument index which is + * relevant when a caller wants to coordinate relevant tainted flows through specific arguments and the return flow. + */ +trait ParamOrRetNode extends FlowNode { + + /** Temporary backward compatible idx field. + * + * @return + * the argument index. + */ + def index: Int +} + +/** A parameter where the index of the argument matches the position of the parameter at the callee. The name is used to + * match named arguments if used instead of positional arguments. + * + * @param index + * the position or argument index. + * @param name + * the name of the parameter. + */ +case class ParameterNode(index: Int, name: Option[String] = None) extends ParamOrRetNode + +object ParameterNode { + def apply(index: Int, name: String): ParameterNode = ParameterNode(index, Option(name)) +} + +/** Represents explicit mappings or special cases. + */ +sealed trait FlowPath + +/** Maps flow between arguments based on how they interact as parameters at the callee. + * + * @param src + * source of the flow. + * @param dst + * destination of the flow. + */ +case class FlowMapping(src: FlowNode, dst: FlowNode) extends FlowPath + +object FlowMapping { + def apply(from: Int, to: Int): FlowMapping = FlowMapping(ParameterNode(from), ParameterNode(to)) + + def apply(fromIdx: Int, from: String, toIdx: Int, to: String): FlowMapping = + FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx, to)) + + def apply(fromIdx: Int, from: String, toIdx: Int): FlowMapping = + FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx)) + + def apply(from: Int, toIdx: Int, to: String): FlowMapping = FlowMapping(ParameterNode(from), ParameterNode(toIdx, to)) + +} + +/** Represents an instance where parameters are not sanitized, may affect the return value, and do not cross-taint. e.g. + * foo(1, 2) = 1 -> 1, 2 -> 2, 1 -> -1, 2 -> -1 + * + * The main benefit is that this works for unbounded parameters e.g. VARARGS. Note this does not taint 0 -> 0. + */ +object PassThroughMapping extends FlowPath diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/DataFlowSlicing.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/DataFlowSlicing.scala index bfcc4a4ad67b..e1c64e8a5694 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/DataFlowSlicing.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/DataFlowSlicing.scala @@ -2,8 +2,7 @@ package io.joern.dataflowengineoss.slicing import io.joern.dataflowengineoss.language.* import io.joern.x2cpg.utils.ConcurrentTaskUtil -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.PropertyNames +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes, Properties} import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* import org.slf4j.LoggerFactory @@ -48,9 +47,9 @@ object DataFlowSlicing { val sliceNodesIdSet = sliceNodes.id.toSet // Lazily set up the rest if the filters are satisfied lazy val sliceEdges = sliceNodes - .flatMap(_.outE) - .filter(x => sliceNodesIdSet.contains(x.inNode().id())) - .map { e => SliceEdge(e.outNode().id(), e.inNode().id(), e.label()) } + .inE(EdgeTypes.REACHING_DEF) + .filter(x => sliceNodesIdSet.contains(x.src.id())) + .map { e => SliceEdge(e.src.id(), e.dst.id(), e.label) } .toSet lazy val slice = Option(DataFlowSlice(sliceNodes.map(cfgNodeToSliceNode).toSet, sliceEdges)) @@ -82,8 +81,8 @@ object DataFlowSlicing { case n: TypeRef => sliceNode.copy(name = n.typeFullName, code = n.code) case n => sliceNode.copy( - name = n.property(PropertyNames.NAME, ""), - typeFullName = n.property(PropertyNames.TYPE_FULL_NAME, "") + name = n.propertyOption(Properties.Name).getOrElse(""), + typeFullName = n.propertyOption(Properties.TypeFullName).getOrElse("") ) } } diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/UsageSlicing.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/UsageSlicing.scala index 2636af71d20a..a95974c2e57f 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/UsageSlicing.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/UsageSlicing.scala @@ -3,7 +3,7 @@ package io.joern.dataflowengineoss.slicing import io.joern.x2cpg.utils.ConcurrentTaskUtil import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames} +import io.shiftleft.codepropertygraph.generated.{Operators, Properties} import io.shiftleft.semanticcpg.language.* import org.slf4j.LoggerFactory @@ -182,15 +182,12 @@ object UsageSlicing { .getOrElse(Iterator.empty) else baseCall.argument) .collect { case n: Expression if n.argumentIndex > 0 => n } - .flatMap { - case _: MethodRef => Option("LAMBDA") + .map { + case _: MethodRef => "LAMBDA" case x => - Option( - x.property( - PropertyNames.TYPE_FULL_NAME, - x.property(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq("ANY")).headOption - ) - ) + x.propertyOption(Properties.TypeFullName) + .orElse(x.property(Properties.DynamicTypeHintFullName).headOption) + .getOrElse("ANY") } .collect { case x: String => x } .toList diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/package.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/package.scala index e1b705041298..6025f8a83517 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/package.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/slicing/package.scala @@ -1,11 +1,10 @@ package io.joern.dataflowengineoss import better.files.File -import io.shiftleft.codepropertygraph.generated.PropertyNames +import io.shiftleft.codepropertygraph.generated.Properties import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* import org.slf4j.LoggerFactory -import overflowdb.PropertyKey import upickle.default.* import java.util.concurrent.{ExecutorService, Executors} @@ -332,13 +331,15 @@ package object slicing { * extracted. */ def fromNode(node: StoredNode, typeMap: Map[String, String] = Map.empty[String, String]): DefComponent = { - val nodeType = (node.property(PropertyNames.TYPE_FULL_NAME, "ANY") +: node.property( - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - Seq.empty[String] - )).filterNot(_.matches("(ANY|UNKNOWN)")).headOption.getOrElse("ANY") + val typeFullNameProperty = node.propertyOption(Properties.TypeFullName).getOrElse("ANY") + val dynamicTypeHintFullNamesProperty = node.property(Properties.DynamicTypeHintFullName) + val nodeType = (typeFullNameProperty +: dynamicTypeHintFullNamesProperty) + .filterNot(_.matches("(ANY|UNKNOWN)")) + .headOption + .getOrElse("ANY") val typeFullName = typeMap.getOrElse(nodeType, nodeType) - val lineNumber = Option(node.property(new PropertyKey[Integer](PropertyNames.LINE_NUMBER))).map(_.toInt) - val columnNumber = Option(node.property(new PropertyKey[Integer](PropertyNames.COLUMN_NUMBER))).map(_.toInt) + val lineNumber = node.propertyOption(Properties.LineNumber) + val columnNumber = node.propertyOption(Properties.ColumnNumber) node match { case x: MethodParameterIn => ParamDef(x.name, typeFullName, x.index, lineNumber, columnNumber) case x: Call if x.code.startsWith("new ") => diff --git a/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/queryengine/AccessPathUsageTests.scala b/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/queryengine/AccessPathUsageTests.scala index a0bb2ea6a4bf..9749294a9984 100644 --- a/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/queryengine/AccessPathUsageTests.scala +++ b/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/queryengine/AccessPathUsageTests.scala @@ -1,13 +1,14 @@ package io.joern.dataflowengineoss.queryengine -import io.shiftleft.OverflowDbTestInstance -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, Operators, Properties} +import flatgraph.{GNode, Graph} +import flatgraph.misc.TestUtils.* +import io.shiftleft.codepropertygraph.generated.PropertyNames +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes, NodeTypes, Operators} import io.joern.dataflowengineoss.queryengine.AccessPathUsage.toTrackedBaseAndAccessPathSimple -import io.shiftleft.semanticcpg.accesspath._ -import org.scalatest.matchers.should.Matchers._ +import io.shiftleft.semanticcpg.accesspath.* +import org.scalatest.matchers.should.Matchers.* import org.scalatest.wordspec.AnyWordSpec -import overflowdb._ class AccessPathUsageTests extends AnyWordSpec { @@ -22,35 +23,28 @@ class AccessPathUsageTests extends AnyWordSpec { private val VS = VariablePointerShift private val S = PointerShift - private val g = OverflowDbTestInstance.create + private val g = Cpg.empty.graph - private def genCALL(graph: Graph, op: String, args: Node*): Call = { - val ret = graph + NodeTypes.CALL // (NodeTypes.CALL, Properties.NAME -> op) - ret.setProperty(Properties.Name, op) + private def genCALL(graph: Graph, op: String, args: GNode*): Call = { + val diffGraphBuilder = Cpg.newDiffGraphBuilder + val newCall = NewCall().name(op) + diffGraphBuilder.addNode(newCall) args.reverse.zipWithIndex.foreach { case (arg, idx) => - ret --- EdgeTypes.ARGUMENT --> arg - arg.setProperty(Properties.ArgumentIndex, idx + 1) + diffGraphBuilder.setNodeProperty(arg, PropertyNames.ARGUMENT_INDEX, idx + 1) + diffGraphBuilder.addEdge(newCall, arg, EdgeTypes.ARGUMENT) } - ret.asInstanceOf[Call] + diffGraphBuilder.apply(graph) + newCall.storedRef.get } - private def genLit(graph: Graph, payload: String): Literal = { - val ret = graph + NodeTypes.LITERAL - ret.setProperty(Properties.Code, payload) - ret.asInstanceOf[Literal] - } + private def genLit(graph: Graph, payload: String): Literal = + graph.addNode(NewLiteral().code(payload)) - private def genID(graph: Graph, payload: String): Identifier = { - val ret = graph + NodeTypes.IDENTIFIER - ret.setProperty(Properties.Name, payload) - ret.asInstanceOf[Identifier] - } + private def genID(graph: Graph, payload: String): Identifier = + graph.addNode(NewIdentifier().name(payload)) - private def genFID(graph: Graph, payload: String): FieldIdentifier = { - val ret = graph + NodeTypes.FIELD_IDENTIFIER - ret.setProperty(Properties.CanonicalName, payload) - ret.asInstanceOf[FieldIdentifier] - } + private def genFID(graph: Graph, payload: String): FieldIdentifier = + graph.addNode(NewFieldIdentifier().canonicalName(payload)) private def toTrackedAccessPath(node: StoredNode): AccessPath = toTrackedBaseAndAccessPathSimple(node)._2 diff --git a/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/semanticsloader/ParserTests.scala b/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/semanticsloader/FullNameSemanticsParserTests.scala similarity index 94% rename from dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/semanticsloader/ParserTests.scala rename to dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/semanticsloader/FullNameSemanticsParserTests.scala index 2c2e08c9fd18..2ed38059a378 100644 --- a/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/semanticsloader/ParserTests.scala +++ b/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/semanticsloader/FullNameSemanticsParserTests.scala @@ -3,10 +3,10 @@ package io.joern.dataflowengineoss.semanticsloader import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -class ParserTests extends AnyWordSpec with Matchers { +class FullNameSemanticsParserTests extends AnyWordSpec with Matchers { class Fixture() { - val parser = new Parser() + val parser = new FullNameSemanticsParser() } "Parser" should { diff --git a/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/testfixtures/SemanticTestCpg.scala b/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/testfixtures/SemanticTestCpg.scala index f2a68aa35341..63c7061e6a62 100644 --- a/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/testfixtures/SemanticTestCpg.scala +++ b/dataflowengineoss/src/test/scala/io/joern/dataflowengineoss/testfixtures/SemanticTestCpg.scala @@ -3,7 +3,7 @@ package io.joern.dataflowengineoss.testfixtures import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlowOptions} import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, FullNameSemantics, Semantics} import io.joern.x2cpg.testfixtures.TestCpg import io.shiftleft.semanticcpg.layers.LayerCreatorContext @@ -12,7 +12,7 @@ import io.shiftleft.semanticcpg.layers.LayerCreatorContext trait SemanticTestCpg { this: TestCpg => protected var _withOssDataflow = false - protected var _extraFlows = List.empty[FlowSemantic] + protected var _semantics: Semantics = DefaultSemantics() protected implicit var context: EngineContext = EngineContext() /** Allows one to enable data-flow analysis capabilities to the TestCpg. @@ -22,10 +22,9 @@ trait SemanticTestCpg { this: TestCpg => this } - /** Allows one to add additional semantics to the engine context during PDG creation. - */ - def withExtraFlows(value: List[FlowSemantic] = List.empty): this.type = { - _extraFlows = value + /** Allows one to provide custom semantics to the TestCpg. */ + def withSemantics(value: Semantics): this.type = { + _semantics = value this } @@ -34,10 +33,11 @@ trait SemanticTestCpg { this: TestCpg => */ def applyOssDataFlow(): Unit = { if (_withOssDataflow) { - val context = new LayerCreatorContext(this) - val options = new OssDataFlowOptions(extraFlows = _extraFlows) - new OssDataFlow(options).run(context) - this.context = EngineContext(Semantics.fromList(DefaultSemantics().elements ++ _extraFlows)) + val context = new LayerCreatorContext(this) + val options = new OssDataFlowOptions(semantics = _semantics) + val dataflow = new OssDataFlow(options) + dataflow.run(context) + this.context = EngineContext(dataflow.semantics) } } @@ -45,8 +45,8 @@ trait SemanticTestCpg { this: TestCpg => /** Allows the tests to make use of the data-flow engine and any additional semantics. */ -trait SemanticCpgTestFixture(extraFlows: List[FlowSemantic] = List.empty) { +trait SemanticCpgTestFixture(semantics: Semantics = DefaultSemantics()) { - implicit val context: EngineContext = EngineContext(Semantics.fromList(DefaultSemantics().elements ++ extraFlows)) + implicit val context: EngineContext = EngineContext(semantics) } diff --git a/joern-cli/build.sbt b/joern-cli/build.sbt index 88b17ef355bd..2360137782b5 100644 --- a/joern-cli/build.sbt +++ b/joern-cli/build.sbt @@ -131,4 +131,26 @@ generateScaladocs := { Universal / packageBin / mappings ++= sbt.Path.directory(new File("joern-cli/src/main/resources/scripts")) +lazy val removeModuleInfoFromJars = taskKey[Unit]("remove module-info.class from dependency jars - a hacky workaround for a scala3 compiler bug https://github.com/scala/scala3/issues/20421") +removeModuleInfoFromJars := { + import java.nio.file.{Files, FileSystems} + val logger = streams.value.log + val libDir = (Universal/stagingDirectory).value / "lib" + + // remove all `/module-info.class` from all jars + Files.walk(libDir.toPath) + .filter(_.toString.endsWith(".jar")) + .forEach { jar => + val zipFs = FileSystems.newFileSystem(jar) + zipFs.getRootDirectories.forEach { zipRootDir => + Files.list(zipRootDir).filter(_.toString == "/module-info.class").forEach { moduleInfoClass => + logger.info(s"workaround for scala completion bug: deleting $moduleInfoClass from $jar") + Files.delete(moduleInfoClass) + } + } + zipFs.close() + } +} +removeModuleInfoFromJars := removeModuleInfoFromJars.triggeredBy(Universal/stage).value + maintainer := "fabs@shiftleft.io" diff --git a/joern-cli/frontends/c2cpg/CPP_Features.md b/joern-cli/frontends/c2cpg/CPP_Features.md new file mode 100644 index 000000000000..af6b85bbedca --- /dev/null +++ b/joern-cli/frontends/c2cpg/CPP_Features.md @@ -0,0 +1,51 @@ +# Support For New Language Features + +- For an explanation for each feature you may want to look at https://github.com/AnthonyCalandra/modern-cpp-features/tree/master. +- Table legend: + - `[?]` not yet checked + - `[ ]` not supported at all / can not even be parsed + - `[~]` can be parsed but is not fully represented in the CPG + - `[x]` full support including the CPG representation + +## C++17 Language Features + +| Feature | Supported | +|-------------------------------------------------------------------------|-----------| +| template argument deduction for class templates | [~] | +| declaring non-type template parameters with auto | [~] | +| folding expressions | [x] | +| new rules for auto deduction from braced-init-list | [x] | +| constexpr lambda | [~] | +| lambda capture this by value | [~] | +| inline variables | [x] | +| nested namespaces | [x] | +| structured bindings | [x] | +| selection statements with initializer | [x] | +| constexpr if | [x] | +| utf-8 character literals | [ ] | +| direct-list-initialization of enums | [x] | +| \[\[fallthrough\]\], \[\[nodiscard\]\], \[\[maybe_unused\]\] attributes | [~] | +| \_\_has_include | [~] | +| class template argument deduction | [~] | + +## C++20 Language Features + +| Feature | Supported | +|------------------------------------------------|-----------| +| coroutines | [~] | +| concepts | [ ] | +| three-way comparison | [ ] | +| designated initializers | [~] | +| template syntax for lambdas | [ ] | +| range-based for loop with initializer | [ ] | +| \[\[likely\]\] and \[\[unlikely\]\] attributes | [x] | +| deprecate implicit capture of this | [~] | +| class types in non-type template parameters | [~] | +| constexpr virtual functions | [~] | +| explicit(bool) | [ ] | +| immediate functions | [x] | +| using enum | [ ] | +| lambda capture of parameter pack | [ ] | +| char8_t | [x] | +| constinit | [+] | +| \_\_VA_OPT\_\_ | [ ] | diff --git a/joern-cli/frontends/c2cpg/eclipse-cdt/CCorePlugin.java b/joern-cli/frontends/c2cpg/eclipse-cdt/CCorePlugin.java new file mode 100644 index 000000000000..27edafac9aed --- /dev/null +++ b/joern-cli/frontends/c2cpg/eclipse-cdt/CCorePlugin.java @@ -0,0 +1,102 @@ +/******************************************************************************* + * Copyright (c) 2000, 2020 IBM Corporation and others. + * + * This program and the accompanying materials + * are made available under the terms of the Eclipse Public License 2.0 + * which accompanies this distribution, and is available at + * https://www.eclipse.org/legal/epl-2.0/ + * + * SPDX-License-Identifier: EPL-2.0 + * + * Contributors: + * IBM Corporation - initial API and implementation + * Markus Schorn (Wind River Systems) + * Andrew Ferguson (Symbian) + * Anton Leherbauer (Wind River Systems) + * oyvind.harboe@zylin.com - http://bugs.eclipse.org/250638 + * Jens Elmenthaler - http://bugs.eclipse.org/173458 (camel case completion) + * Sergey Prigogin (Google) + * Alexander Fedorov (ArSysOp) - Bug 561992 + *******************************************************************************/ + +package org.eclipse.cdt.core; + +import org.eclipse.cdt.core.model.CModelException; +import org.eclipse.core.runtime.CoreException; +import org.eclipse.core.runtime.IStatus; +import org.eclipse.core.runtime.Plugin; +import org.eclipse.core.runtime.Status; +import org.osgi.framework.Version; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This is a stripped-down version of the original org.eclipse.cdt.core.CCorePlugin shadowing it and + * providing only the functionality to get it running without all the Eclipse OSGI context. + * Sadly, some parser internal components (e.g., the ambiguity resolving) log via this class. + * Without a proper OSGI instantiation (which we do not have/want in Joern) we are running into + * all kind of exceptions due to non-initialized entities (e.g., the said logging utils). + */ +public class CCorePlugin extends Plugin { + + private static final Logger logger = LoggerFactory.getLogger(CCorePlugin.class); + + public static Version getCDTFeatureVersion() { + return null; + } + + public static IStatus createStatus(String msg) { + return createStatus(msg, null); + } + + public static IStatus createStatus(String msg, Throwable e) { + return new Status(4, "org.eclipse.cdt.core", msg, e); + } + + public static void log(String e) { + log(createStatus(e)); + } + + public static void log(int severity, String msg) { + log(new Status(severity, "org.eclipse.cdt.core", msg)); + } + + public static void logStackTrace(int severity, String msg) { + log(new Status(severity, "org.eclipse.cdt.core", msg, new Exception())); + } + + public static void log(String message, Throwable e) { + Throwable nestedException; + if (e instanceof CModelException && (nestedException = ((CModelException) e).getException()) != null) { + e = nestedException; + } + log(createStatus(message, e)); + } + + public static void log(Throwable e) { + if (e instanceof CoreException) { + IStatus status = ((CoreException) e).getStatus(); + if (status.getException() != null) { + log(status); + } else { + log(createStatus("Error", e)); + } + } else { + String msg = e.getMessage(); + if (msg == null) { + log("Error", e); + } else { + log("Error: " + msg, e); + } + } + } + + public static void log(IStatus status) { + Throwable throwable; + if ((throwable = status.getException()) != null) { + String msg = throwable.getMessage(); + logger.debug(msg, throwable); + } + } + +} diff --git a/joern-cli/frontends/c2cpg/eclipse-cdt/eclipse-cdt-core-publish.sh b/joern-cli/frontends/c2cpg/eclipse-cdt/eclipse-cdt-core-publish.sh index 4015b994ef14..a3b76c5f3520 100755 --- a/joern-cli/frontends/c2cpg/eclipse-cdt/eclipse-cdt-core-publish.sh +++ b/joern-cli/frontends/c2cpg/eclipse-cdt/eclipse-cdt-core-publish.sh @@ -4,54 +4,47 @@ set -o pipefail # this script downloads a cdt-core release from a configurable location # (e.g. an eclipse release mirror, or their jenkins CI) and publishes it to -# sonatype, so that we can promote it to maven central: +# sonatype, so that we can promote it to maven central. +# note: we also swap the original CCorePlugin.java for a simplified one # context: eclipse uses their own repository format called p2 tycho, # but tooling is limited # Some related links: # https://ci.eclipse.org/cdt/job/cdt/job/main # https://ci.eclipse.org/cdt/job/cdt/job/main/353/artifact/releng/org.eclipse.cdt.repo/target/repository/plugins/org.eclipse.cdt.core_8.4.0.202401242025.jar # https://ftp.fau.de/eclipse/tools/cdt/releases/11.4/cdt-11.4.0/plugins/ +# https://github.com/joernio/joern/pull/5178 # -# https://s01.oss.sonatype.org/content/groups/public/io/joern/eclipse-cdt-core/ # https://repo1.maven.org/maven2/io/joern/eclise-cdt-core/ # https://github.com/digimead/sbt-osgi-manager/blob/master/src/main/scala/sbt/osgi/manager/tycho/ResolveP2.scala # adapt for every release -JAR_URL='https://ci.eclipse.org/cdt/job/cdt/job/main/353/artifact/releng/org.eclipse.cdt.repo/target/repository/plugins/org.eclipse.cdt.core_8.4.0.202401242025.jar' -CUSTOM_RELEASE_VERSION='8.4.0.202401242025' - -# adapt when releasing from a different machine: the server id from your local ~/.m2/settings.xml -REPO_ID=sonatype-nexus-staging-joern - - - +JAR_URL='https://ci.eclipse.org/cdt/job/cdt/job/main/452/artifact/releng/org.eclipse.cdt.repo/target/repository/plugins/org.eclipse.cdt.core_8.5.0.202410191453.jar' +CUSTOM_RELEASE_VERSION='8.5.0.202410191453+3' LOCAL_JAR="org.eclipse.cdt.core-$CUSTOM_RELEASE_VERSION.jar" echo "downloading jar from $JAR_URL to $LOCAL_JAR" wget $JAR_URL -O $LOCAL_JAR -# install into local maven repo, mostly to generate a pom -mvn install:install-file -DgroupId=io.joern -DartifactId=eclipse-cdt-core -Dpackaging=jar -Dversion=$CUSTOM_RELEASE_VERSION -Dfile=$LOCAL_JAR -DgeneratePom=true -cp ~/.m2/repository/io/joern/eclipse-cdt-core/$CUSTOM_RELEASE_VERSION/eclipse-cdt-core-$CUSTOM_RELEASE_VERSION.pom pom.xml - -# add pom-extra to pom.xml, to make sonatype happy -head -n -1 pom.xml > pom.tmp -cat pom.tmp pom-extra > pom.xml -rm pom.tmp - -# create empty jar for "sources" - just to make sonatype happy -zip empty.jar LICENSE - -# sign and upload artifacts to sonatype staging -SONATYPE_URL=https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ -mvn gpg:sign-and-deploy-file -Durl=$SONATYPE_URL -DrepositoryId=$REPO_ID -DpomFile=pom.xml -Dclassifier=sources -Dfile=empty.jar -mvn gpg:sign-and-deploy-file -Durl=$SONATYPE_URL -DrepositoryId=$REPO_ID -DpomFile=pom.xml -Dclassifier=javadoc -Dfile=empty.jar -mvn gpg:sign-and-deploy-file -Durl=$SONATYPE_URL -DrepositoryId=$REPO_ID -DpomFile=pom.xml -Dfile=$LOCAL_JAR - -# remove temporary working artifacts -rm $LOCAL_JAR pom.xml empty.jar *.asc - -echo "artifacts are now published to sonatype staging. next step: log into https://s01.oss.sonatype.org -> staging repositories -> select the right one -> close -> release" -echo "you can monitor the maven sync status on https://s01.oss.sonatype.org/content/groups/public/io/joern/eclipse-cdt-core/ and https://repo1.maven.org/maven2/io/joern/eclipse-cdt-core/" -echo "once it's synchronised to maven central (repo1), update the cdt-core version in 'joern/joern-cli/frontends/c2cpg/build.sbt'" - +# create custom-made maven build just for deploying to maven central +rm -rf build +mkdir build +pushd build +sed s/__VERSION__/$CUSTOM_RELEASE_VERSION/ ../pom.xml.template > pom.xml +mkdir -p src/main/resources +unzip -d src/main/resources ../$LOCAL_JAR \ + -x 'META-INF/*.RSA' 'META-INF/*.SF' \ + 'org/eclipse/cdt/core/CCorePlugin*.class' +# passing -x option to exclude some files: +# 1) original signing information, otherwise the jar is unusable because the signature doesn't match +# 2) original CCorePlugin.class because we want to replace it with our simplified version + +# add our custom CCorePlugin.java +mkdir -p src/main/java/org/eclipse/cdt/core +cp ../CCorePlugin.java src/main/java/org/eclipse/cdt/core + +# deploy to sonatype central +mvn javadoc:jar source:jar package gpg:sign deploy +popd + +echo "release is now published to sonatype central and should get promoted to maven central automatically. For more context go to https://central.sonatype.com/publishing/deployments" +echo "once it's synchronised to maven central (https://repo1.maven.org/maven2/io/joern/eclipse-cdt-core/), update the cdt-core version in 'joern/joern-cli/frontends/c2cpg/build.sbt' to $CUSTOM_RELEASE_VERSION" diff --git a/joern-cli/frontends/c2cpg/eclipse-cdt/pom.xml.template b/joern-cli/frontends/c2cpg/eclipse-cdt/pom.xml.template new file mode 100644 index 000000000000..0f1089a6a818 --- /dev/null +++ b/joern-cli/frontends/c2cpg/eclipse-cdt/pom.xml.template @@ -0,0 +1,103 @@ + + + 4.0.0 + io.joern + eclipse-cdt-core + __VERSION__ + + UTF-8 + UTF-8 + + jar + cdt re-release for joern + https://github.com/eclipse-cdt/cdt/ + + + Eclipse Public License 2.0 + + This program and the accompanying materials are made + available under the terms of the Eclipse Public License 2.0 + which accompanies this distribution, and is available at + https://www.eclipse.org/legal/epl-2.0/ + + SPDX-License-Identifier: EPL-2.0 + + + + cdt + + https://github.com/eclipse-cdt/cdt + scm:git@github.com:eclipse-cdt/cdt.git + + + + + max-leuthaeuser + Max Leuthaeuser + https://github.com/max-leuthaeuser + max@qwiet.ai + + + mpollmeier + Michael Pollmeier + http://www.michaelpollmeier.com + michael@michaelpollmeier.com + + + + + org.eclipse.platform + org.eclipse.core.runtime + 3.32.0 + + + org.slf4j + slf4j-api + 2.0.7 + + + + + + org.sonatype.central + central-publishing-maven-plugin + 0.6.0 + true + + sonatype-central-joern + cdt-__VERSION__ + true + + + + org.apache.maven.plugins + maven-gpg-plugin + 3.2.7 + + + org.apache.maven.plugins + maven-source-plugin + 3.3.1 + + + org.apache.maven.plugins + maven-javadoc-plugin + 3.11.1 + + + + + io.joern + eclipse-cdt-core + 8.5.0.202410191453+2 + + + + + + + + diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/C2Cpg.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/C2Cpg.scala index b59c5d0b22a2..c5a4a0ca20a4 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/C2Cpg.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/C2Cpg.scala @@ -1,26 +1,32 @@ package io.joern.c2cpg import io.joern.c2cpg.passes.{AstCreationPass, PreprocessorPass, TypeDeclNodePass} +import io.joern.c2cpg.passes.FunctionDeclNodePass import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.Languages import io.joern.x2cpg.passes.frontend.{MetaDataPass, TypeNodePass} import io.joern.x2cpg.X2Cpg.withNewEmptyCpg import io.joern.x2cpg.X2CpgFrontend import io.joern.x2cpg.utils.Report +import org.slf4j.LoggerFactory import java.util.regex.Pattern +import scala.util.control.NonFatal import scala.util.Try import scala.util.matching.Regex class C2Cpg extends X2CpgFrontend[Config] { - private val report: Report = new Report() + private val logger = LoggerFactory.getLogger(classOf[C2Cpg]) def createCpg(config: Config): Try[Cpg] = { withNewEmptyCpg(config.outputPath, config) { (cpg, config) => + val report = new Report() new MetaDataPass(cpg, Languages.NEWC, config.inputPath).createAndApply() val astCreationPass = new AstCreationPass(cpg, config, report) astCreationPass.createAndApply() + new FunctionDeclNodePass(cpg, astCreationPass.unhandledMethodDeclarations())(config.schemaValidation) + .createAndApply() TypeNodePass.withRegisteredTypes(astCreationPass.typesSeen(), cpg).createAndApply() new TypeDeclNodePass(cpg)(config.schemaValidation).createAndApply() report.print() @@ -28,8 +34,14 @@ class C2Cpg extends X2CpgFrontend[Config] { } def printIfDefsOnly(config: Config): Unit = { - val stmts = new PreprocessorPass(config).run().mkString(",") - println(stmts) + try { + val stmts = new PreprocessorPass(config).run().mkString(",") + println(stmts) + } catch { + case NonFatal(ex) => + logger.error("Failed to print preprocessor statements.", ex) + throw ex + } } } @@ -39,7 +51,7 @@ object C2Cpg { private val EscapedFileSeparator = Pattern.quote(java.io.File.separator) val DefaultIgnoredFolders: List[Regex] = List( - "\\..*".r, + s"(.*[$EscapedFileSeparator])?\\..*".r, s"(.*[$EscapedFileSeparator])?tests?[$EscapedFileSeparator].*".r, s"(.*[$EscapedFileSeparator])?CMakeFiles[$EscapedFileSeparator].*".r ) diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/Main.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/Main.scala index 358a14f3c1c5..61b0675dcbc2 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/Main.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/Main.scala @@ -1,12 +1,12 @@ package io.joern.c2cpg -import io.joern.c2cpg.Frontend._ +import io.joern.c2cpg.Frontend.* import io.joern.x2cpg.{X2CpgConfig, X2CpgMain} +import io.joern.x2cpg.utils.server.FrontendHTTPServer +import io.joern.x2cpg.SourceFiles import org.slf4j.LoggerFactory import scopt.OParser -import scala.util.control.NonFatal - final case class Config( includePaths: Set[String] = Set.empty, defines: Set[String] = Set.empty, @@ -17,7 +17,8 @@ final case class Config( includePathsAutoDiscovery: Boolean = false, skipFunctionBodies: Boolean = false, noImageLocations: Boolean = false, - withPreprocessedFiles: Boolean = false + withPreprocessedFiles: Boolean = false, + compilationDatabase: Option[String] = None ) extends X2CpgConfig[Config] { def withIncludePaths(includePaths: Set[String]): Config = { this.copy(includePaths = includePaths).withInheritedFields(this) @@ -58,6 +59,10 @@ final case class Config( def withPreprocessedFiles(value: Boolean): Config = { this.copy(withPreprocessedFiles = value).withInheritedFields(this) } + + def withCompilationDatabase(value: String): Config = { + this.copy(compilationDatabase = Some(value)).withInheritedFields(this) + } } private object Frontend { @@ -94,9 +99,9 @@ private object Frontend { .text("instructs the parser to skip function and method bodies.") .action((_, c) => c.withSkipFunctionBodies(true)), opt[Unit]("no-image-locations") - .text( - "performance optimization, allows the parser not to create image-locations. An image location explains how a name made it into the translation unit. Eg: via macro expansion or preprocessor." - ) + .text("""performance optimization, allows the parser not to create image-locations. + | An image location explains how a name made it into the translation unit. + | E.g., via macro expansion or preprocessor.""".stripMargin) .action((_, c) => c.withNoImageLocations(true)), opt[Unit]("with-preprocessed-files") .text("includes *.i files and gives them priority over their unprocessed origin source files.") @@ -104,27 +109,29 @@ private object Frontend { opt[String]("define") .unbounded() .text("define a name") - .action((d, c) => c.withDefines(c.defines + d)) + .action((d, c) => c.withDefines(c.defines + d)), + opt[String]("compilation-database") + .text("""enables the processing of compilation database files (e.g., compile_commands.json). + | This allows to automatically extract compiler options, source files, and other build information from the specified database + | and ensuring consistency with the build configuration. + | For a cmake based build such a file is generated with the environment variable CMAKE_EXPORT_COMPILE_COMMANDS being present. + | Clang based build are supported e.g., with https://github.com/rizsotto/Bear + | """.stripMargin) + .action((d, c) => c.withCompilationDatabase(SourceFiles.toAbsolutePath(d, c.inputPath))) ) } } -object Main extends X2CpgMain(cmdLineParser, new C2Cpg()) { - - private val logger = LoggerFactory.getLogger(classOf[C2Cpg]) - - def run(config: Config, c2cpg: C2Cpg): Unit = { - if (config.printIfDefsOnly) { - try { - c2cpg.printIfDefsOnly(config) - } catch { - case NonFatal(ex) => - logger.error("Failed to print preprocessor statements.", ex) - throw ex - } - } else { - c2cpg.run(config) +object Main extends X2CpgMain(cmdLineParser, new C2Cpg()) with FrontendHTTPServer[Config, C2Cpg] { + + override protected def newDefaultConfig(): Config = Config() + + override def run(config: Config, c2cpg: C2Cpg): Unit = { + config match { + case c if c.serverMode => startup() + case c if c.printIfDefsOnly => c2cpg.printIfDefsOnly(config) + case _ => c2cpg.run(config) } } diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstCreator.scala index 6f4ca5fbd34b..7c9acf3c2ba7 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstCreator.scala @@ -1,16 +1,15 @@ package io.joern.c2cpg.astcreation import io.joern.c2cpg.Config -import io.joern.x2cpg.datastructures.Scope +import io.joern.c2cpg.parser.HeaderFileFinder import io.joern.x2cpg.datastructures.Stack.* import io.joern.x2cpg.{Ast, AstCreatorBase, ValidationMode, AstNodeBuilder as X2CpgAstNodeBuilder} -import io.joern.x2cpg.datastructures.Global import io.shiftleft.codepropertygraph.generated.NodeTypes import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import org.eclipse.cdt.core.dom.ast.{IASTNode, IASTTranslationUnit} import org.slf4j.{Logger, LoggerFactory} -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable @@ -19,9 +18,10 @@ import scala.collection.mutable */ class AstCreator( val filename: String, - val global: Global, + val global: CGlobal, val config: Config, val cdtAst: IASTTranslationUnit, + val headerFileFinder: HeaderFileFinder, val file2OffsetTable: ConcurrentHashMap[String, Array[Int]] )(implicit withSchemaValidation: ValidationMode) extends AstCreatorBase(filename) @@ -30,19 +30,21 @@ class AstCreator( with AstForPrimitivesCreator with AstForStatementsCreator with AstForExpressionsCreator + with AstForLambdasCreator with AstNodeBuilder with AstCreatorHelper + with FullNameProvider with MacroHandler with X2CpgAstNodeBuilder[IASTNode, AstCreator] { protected val logger: Logger = LoggerFactory.getLogger(classOf[AstCreator]) - protected val scope: Scope[String, (NewNode, String), NewNode] = new Scope() + protected val scope: C2CpgScope = new C2CpgScope() protected val usingDeclarationMappings: mutable.Map[String, String] = mutable.HashMap.empty // TypeDecls with their bindings (with their refs) for lambdas and methods are not put in the AST - // where the respective nodes are defined. Instead we put them under the parent TYPE_DECL in which they are defined. + // where the respective nodes are defined. Instead, we put them under the parent TYPE_DECL in which they are defined. // To achieve this we need this extra stack. protected val methodAstParentStack: Stack[NewNode] = new Stack() @@ -82,12 +84,12 @@ class AstCreator( methodAstParentStack.push(fakeGlobalMethod) scope.pushNewScope(fakeGlobalMethod) - val blockNode_ = blockNode(iASTTranslationUnit, Defines.empty, registerType(Defines.anyTypeName)) + val blockNode_ = blockNode(iASTTranslationUnit) val declsAsts = allDecls.flatMap(astsForDeclaration) setArgumentIndices(declsAsts) - val methodReturn = newMethodReturnNode(iASTTranslationUnit, Defines.anyTypeName) + val methodReturn = methodReturnNode(iASTTranslationUnit, Defines.Any) Ast(fakeGlobalTypeDecl).withChild( methodAst(fakeGlobalMethod, Seq.empty, blockAst(blockNode_, declsAsts), methodReturn) ) @@ -100,7 +102,7 @@ class AstCreator( } override protected def lineEnd(node: IASTNode): Option[Int] = { - nullSafeFileLocation(node).map(_.getEndingLineNumber) + nullSafeFileLocationLast(node).map(_.getEndingLineNumber) } protected def column(node: IASTNode): Option[Int] = { diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstCreatorHelper.scala index a15e0f844637..668cf8f5cb0b 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstCreatorHelper.scala @@ -1,53 +1,50 @@ package io.joern.c2cpg.astcreation -import io.shiftleft.codepropertygraph.generated.nodes.{ExpressionNew, NewCall, NewNode} -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.joern.x2cpg.{Ast, SourceFiles, ValidationMode} +import io.joern.x2cpg.Ast +import io.joern.x2cpg.SourceFiles +import io.joern.x2cpg.ValidationMode +import io.joern.x2cpg.Defines as X2CpgDefines import io.joern.x2cpg.utils.NodeBuilders.newDependencyNode +import io.shiftleft.codepropertygraph.generated.DispatchTypes +import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.nodes.ExpressionNew +import io.shiftleft.codepropertygraph.generated.nodes.NewCall +import io.shiftleft.codepropertygraph.generated.nodes.NewNode import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.utils.IOUtils import org.apache.commons.lang3.StringUtils import org.eclipse.cdt.core.dom.ast.* -import org.eclipse.cdt.core.dom.ast.c.{ICASTArrayDesignator, ICASTDesignatedInitializer, ICASTFieldDesignator} +import org.eclipse.cdt.core.dom.ast.c.ICASTArrayDesignator +import org.eclipse.cdt.core.dom.ast.c.ICASTDesignatedInitializer +import org.eclipse.cdt.core.dom.ast.c.ICASTFieldDesignator import org.eclipse.cdt.core.dom.ast.cpp.* import org.eclipse.cdt.core.dom.ast.gnu.c.ICASTKnRFunctionDeclarator -import org.eclipse.cdt.internal.core.dom.parser.c.{CASTArrayRangeDesignator, CASTFunctionDeclarator} +import org.eclipse.cdt.internal.core.dom.parser.c.CASTArrayRangeDesignator +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTArrayRangeDesignator +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTFieldReference +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTIdExpression +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPMethod +import org.eclipse.cdt.internal.core.dom.parser.cpp.ICPPEvaluation import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.EvalBinding -import org.eclipse.cdt.internal.core.dom.parser.cpp.{ - CPPASTArrayRangeDesignator, - CPPASTFieldReference, - CPPASTFunctionDeclarator, - CPPASTIdExpression, - CPPFunction, - CPPMethod, - ICPPEvaluation -} import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.EvalMemberAccess +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTFoldExpression +import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.EvalBinary +import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.EvalFoldExpression +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTEqualsInitializer +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPFunction +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPVariable import org.eclipse.cdt.internal.core.model.ASTStringUtil -import java.nio.file.{Path, Paths} +import java.nio.file.Path +import java.nio.file.Paths import scala.annotation.nowarn import scala.collection.mutable import scala.util.Try -object AstCreatorHelper { - - implicit class OptionSafeAst(val ast: Ast) extends AnyVal { - def withArgEdge(src: NewNode, dst: Option[NewNode]): Ast = dst match { - case Some(value) => ast.withArgEdge(src, value) - case None => ast - } - } -} - trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - import io.joern.c2cpg.astcreation.AstCreatorHelper.* - private var usedVariablePostfix: Int = 0 - private val IncludeKeyword = "include" - protected def isIncludedNode(node: IASTNode): Boolean = fileName(node) != filename protected def uniqueName(target: String, name: String, fullName: String): (String, String) = { @@ -80,9 +77,11 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As protected def nullSafeFileLocation(node: IASTNode): Option[IASTFileLocation] = Option(cdtAst.flattenLocationsToFile(node.getNodeLocations)).map(_.asFileLocation()) + protected def nullSafeFileLocationLast(node: IASTNode): Option[IASTFileLocation] = + Option(cdtAst.flattenLocationsToFile(node.getNodeLocations.lastOption.toArray)).map(_.asFileLocation()) protected def fileName(node: IASTNode): String = { - val path = nullSafeFileLocation(node).map(_.getFileName).getOrElse(filename) + val path = Try(node.getContainingFilename).getOrElse(filename) SourceFiles.toRelativePath(path, config.inputPath) } @@ -105,12 +104,19 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As fixedTypeName } + protected def registerMethodDeclaration(fullName: String, methodInfo: CGlobal.MethodInfo): Unit = { + global.methodDeclarations.putIfAbsent(fullName, methodInfo) + } + + protected def registerMethodDefinition(fullName: String): Unit = { + global.methodDefinitions.putIfAbsent(fullName, true) + } + // Sadly, there is no predefined List / Enum of this within Eclipse CDT: - private val reservedTypeKeywords: List[String] = + private val ReservedKeywordsAtTypes: List[String] = List( "const", "static", - "volatile", "restrict", "extern", "typedef", @@ -124,114 +130,212 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As "class" ) - protected def cleanType(rawType: String, stripKeywords: Boolean = true): String = { - val tpe = - if (stripKeywords) { - reservedTypeKeywords.foldLeft(rawType) { (cur, repl) => - if (cur.contains(s"$repl ")) { - dereferenceTypeFullName(cur.replace(s"$repl ", "")) - } else { - cur - } + private val KeywordsAtTypesToKeep: List[String] = List("unsigned", "volatile") + + protected def cleanType(rawType: String): String = { + if (rawType == Defines.Any) return rawType + val normalizedTpe = StringUtils.normalizeSpace(rawType.stripSuffix(" ()")) + val tpe = ReservedKeywordsAtTypes.foldLeft(normalizedTpe) { (cur, repl) => + if (cur.startsWith(s"$repl ") || cur.contains(s" $repl ")) { + cur.replace(s" $repl ", " ").stripPrefix(s"$repl ") + } else cur + } + replaceWhitespaceAfterKeyword(tpe) match { + case "" => Defines.Any + case t if t.startsWith("[") && t.endsWith("]") => Defines.Array + case t if isThisLambdaCapture(t) || t.contains("->") => Defines.Function + case t if t.contains("?") => Defines.Any + case t if t.contains("#") => Defines.Any + case t if t.contains("::{") || t.contains("}::") => Defines.Any + case t if t.contains("{") || t.contains("}") => Defines.Any + case t if t.contains("org.eclipse.cdt.internal.core.dom.parser.ProblemType") => Defines.Any + case t if t.contains("( ") => fixQualifiedName(t.substring(0, t.indexOf("( "))) + case someType => fixQualifiedName(someType) + } + } + + private def replaceWhitespaceAfterKeyword(tpe: String): String = { + if (KeywordsAtTypesToKeep.exists(k => tpe.startsWith(s"$k ") || tpe.contains(s" $k "))) { + KeywordsAtTypesToKeep.foldLeft(tpe) { (cur, repl) => + val prefixStartsWith = s"$repl " + val prefixContains = s" $repl " + if (cur.startsWith(prefixStartsWith)) { + prefixStartsWith + replaceWhitespaceAfterKeyword(cur.substring(prefixStartsWith.length)) + } else if (cur.contains(prefixContains)) { + val front = tpe.substring(0, tpe.indexOf(prefixContains)) + val back = tpe.substring(tpe.indexOf(prefixContains) + prefixContains.length) + s"${replaceWhitespaceAfterKeyword(front)}$prefixContains${replaceWhitespaceAfterKeyword(back)}" + } else { + cur } - } else { - rawType } - StringUtils.normalizeSpace(tpe) match { - case "" => Defines.anyTypeName - case t if t.contains("org.eclipse.cdt.internal.core.dom.parser.ProblemType") => Defines.anyTypeName - case t if t.contains(" ->") && t.contains("}::") => - fixQualifiedName(t.substring(t.indexOf("}::") + 3, t.indexOf(" ->"))) - case t if t.contains(" ->") => - fixQualifiedName(t.substring(0, t.indexOf(" ->"))) - case t if t.contains("( ") => - fixQualifiedName(t.substring(0, t.indexOf("( "))) - case t if t.contains("?") => Defines.anyTypeName - case t if t.contains("#") => Defines.anyTypeName - case t if t.contains("{") && t.contains("}") => - val anonType = - s"${uniqueName("type", "", "")._1}${t.substring(0, t.indexOf("{"))}${t.substring(t.indexOf("}") + 1)}" - anonType.replace(" ", "") - case t if t.startsWith("[") && t.endsWith("]") => Defines.anyTypeName - case t if t.contains(Defines.qualifiedNameSeparator) => fixQualifiedName(t) - case t if t.startsWith("unsigned ") => "unsigned " + t.substring(9).replace(" ", "") - case t if t.contains("[") && t.contains("]") => t.replace(" ", "") - case t if t.contains("*") => t.replace(" ", "") - case someType => someType + } else { + tpe.replace(" ", "") } } - private def safeGetEvaluation(expr: ICPPASTExpression): Option[ICPPEvaluation] = { + private def isThisLambdaCapture(tpe: String): Boolean = { + tpe.startsWith("[*this]") || tpe.startsWith("[this]") || (tpe.startsWith("[") && tpe.contains("this]")) + } + + protected def safeGetEvaluation(expr: ICPPASTExpression): Option[ICPPEvaluation] = { // In case of unresolved includes etc. this may fail throwing an unrecoverable exception Try(expr.getEvaluation).toOption } + protected def safeGetBinding(name: IASTName): Option[IBinding] = { + // In case of unresolved includes etc. this may fail throwing an unrecoverable exception + Try(name.resolveBinding()).toOption + } + + protected def safeGetBinding(idExpression: IASTIdExpression): Option[IBinding] = { + // In case of unresolved includes etc. this may fail throwing an unrecoverable exception + safeGetBinding(idExpression.getName).collect { + case binding: IBinding if !binding.isInstanceOf[IProblemBinding] => binding + } + } + + protected def safeGetBinding(spec: IASTNamedTypeSpecifier): Option[IBinding] = { + // In case of unresolved includes etc. this may fail throwing an unrecoverable exception + safeGetBinding(spec.getName).collect { + case binding: IBinding if !binding.isInstanceOf[IProblemBinding] => binding + } + } + + protected def safeGetType(tpe: IType): String = { + // In case of unresolved includes etc. this may fail throwing an unrecoverable exception + Try(ASTTypeUtil.getType(tpe)).getOrElse(Defines.Any) + } + + private def safeGetNodeType(node: IASTNode): String = { + // In case of unresolved includes etc. this may fail throwing an unrecoverable exception + Try(ASTTypeUtil.getNodeType(node)).getOrElse(Defines.Any) + } + + private def typeForCPPASTFieldReference(f: CPPASTFieldReference): String = { + safeGetEvaluation(f.getFieldOwner) match { + case Some(evaluation: EvalBinding) => cleanType(evaluation.getType.toString) + case _ => cleanType(safeGetType(f.getFieldOwner.getExpressionType)) + } + } + + private def typeForCPPASTFoldExpression(f: CPPASTFoldExpression): String = { + safeGetEvaluation(f) match { + case Some(evaluation: EvalFoldExpression) => + Try(evaluation.getValue.getEvaluation).toOption match { + case Some(value: EvalBinary) => + val s = value.toString + cleanType(s.substring(0, s.indexOf(": "))) + case Some(value: EvalBinding) if value.getType.isInstanceOf[ICPPParameterPackType] => + val s = value.getType.asInstanceOf[ICPPParameterPackType].getType.toString + cleanType(s) + case _ => Defines.Any + } + case _ => Defines.Any + } + } + @nowarn - protected def typeFor(node: IASTNode, stripKeywords: Boolean = true): String = { + private def typeForIASTArrayDeclarator(a: IASTArrayDeclarator): String = { import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature - node match { - case f: CPPASTFieldReference => - safeGetEvaluation(f.getFieldOwner) match { - case Some(evaluation: EvalBinding) => cleanType(evaluation.getType.toString, stripKeywords) - case _ => cleanType(ASTTypeUtil.getType(f.getFieldOwner.getExpressionType), stripKeywords) - } - case f: IASTFieldReference => - cleanType(ASTTypeUtil.getType(f.getFieldOwner.getExpressionType), stripKeywords) - case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).startsWith("? ") => - val tpe = getNodeSignature(a).replace("[]", "").strip() - val arr = ASTTypeUtil.getNodeType(a).replace("? ", "") - s"$tpe$arr" - case a: IASTArrayDeclarator - if ASTTypeUtil.getNodeType(a).contains("} ") || ASTTypeUtil.getNodeType(a).contains(" [") => - val tpe = getNodeSignature(a).replace("[]", "").strip() - val arr = a.getArrayModifiers.map { - case m if m.getConstantExpression != null => s"[${nodeSignature(m.getConstantExpression)}]" - case _ if a.getInitializer != null => - a.getInitializer match { - case l: IASTInitializerList => s"[${l.getSize}]" - case _ => "[]" - } - case _ => "[]" - }.mkString - s"$tpe$arr" - case s: CPPASTIdExpression => - safeGetEvaluation(s) match { - case Some(evaluation: EvalMemberAccess) => - cleanType(evaluation.getOwnerType.toString, stripKeywords) - case Some(evalBinding: EvalBinding) => - evalBinding.getBinding match { - case m: CPPMethod => cleanType(fullName(m.getDefinition)) - case _ => cleanType(ASTTypeUtil.getNodeType(s), stripKeywords) - } - case _ => cleanType(ASTTypeUtil.getNodeType(s), stripKeywords) + if (safeGetNodeType(a).startsWith("? ")) { + val tpe = getNodeSignature(a).replace("[]", "").strip() + val arr = safeGetNodeType(a).replace("? ", "") + s"$tpe$arr" + } else if (safeGetNodeType(a).contains("} ") || safeGetNodeType(a).contains(" [")) { + val tpe = getNodeSignature(a).replace("[]", "").strip() + val arr = a.getArrayModifiers.map { + case m if m.getConstantExpression != null => s"[${nodeSignature(m.getConstantExpression)}]" + case _ if a.getInitializer != null => + a.getInitializer match { + case l: IASTInitializerList => s"[${l.getSize}]" + case _ => "[]" + } + case _ => "[]" + }.mkString + s"$tpe$arr" + } else { + cleanType(safeGetNodeType(a)) + } + } + + private def typeForCPPASTIdExpression(s: CPPASTIdExpression): String = { + safeGetEvaluation(s) match { + case Some(evaluation: EvalMemberAccess) => + val deref = if (evaluation.isPointerDeref) "*" else "" + cleanType(evaluation.getOwnerType.toString + deref) + case Some(evalBinding: EvalBinding) => + evalBinding.getBinding match { + case m: CPPMethod => cleanType(safeGetNodeType(m.getPrimaryDeclaration)) + case f: CPPFunction => cleanType(safeGetNodeType(f.getDefinition)) + case v: CPPVariable => cleanType(v.getType.toString) + case _ => cleanType(safeGetNodeType(s)) } - case _: IASTIdExpression | _: IASTName | _: IASTDeclarator => - cleanType(ASTTypeUtil.getNodeType(node), stripKeywords) - case s: IASTNamedTypeSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case s: IASTCompositeTypeSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case s: IASTEnumerationSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case s: IASTElaboratedTypeSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case l: IASTLiteralExpression => - cleanType(ASTTypeUtil.getType(l.getExpressionType)) - case e: IASTExpression => - cleanType(ASTTypeUtil.getNodeType(e), stripKeywords) - case c: ICPPASTConstructorInitializer if c.getParent.isInstanceOf[ICPPASTConstructorChainInitializer] => - cleanType( - fullName(c.getParent.asInstanceOf[ICPPASTConstructorChainInitializer].getMemberInitializerId), - stripKeywords - ) + case _ => cleanType(safeGetNodeType(s)) + } + } + + @nowarn + private def typeForICPPASTConstructorInitializer(c: ICPPASTConstructorInitializer): String = { + import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature + c.getParent match { + case initializer: ICPPASTConstructorChainInitializer => + val initIdFullName = fullName(initializer.getMemberInitializerId) + cleanType(initIdFullName) case _ => - cleanType(getNodeSignature(node), stripKeywords) + cleanType(getNodeSignature(c)) + } + } + + private def typeForCPPASTEqualsInitializer(c: CPPASTEqualsInitializer): String = { + import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature + c.getInitializerClause match { + case initializer: ICPPASTFunctionCallExpression + if initializer.getFunctionNameExpression.isInstanceOf[CPPASTIdExpression] => + val name = initializer.getFunctionNameExpression.asInstanceOf[CPPASTIdExpression] + typeForCPPASTIdExpression(name) + case _ => + cleanType(getNodeSignature(c)) + } + } + + @nowarn + protected def typeFor(node: IASTNode): String = { + import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature + node match { + case f: CPPASTFoldExpression => typeForCPPASTFoldExpression(f) + case f: CPPASTFieldReference => typeForCPPASTFieldReference(f) + case s: CPPASTIdExpression => typeForCPPASTIdExpression(s) + case s: ICPPASTNamedTypeSpecifier => typeForCPPAstNamedTypeSpecifier(s) + case a: IASTArrayDeclarator => typeForIASTArrayDeclarator(a) + case c: ICPPASTConstructorInitializer => typeForICPPASTConstructorInitializer(c) + case c: CPPASTEqualsInitializer => typeForCPPASTEqualsInitializer(c) + case _: IASTIdExpression | _: IASTName | _: IASTDeclarator => cleanType(safeGetNodeType(node)) + case f: IASTFieldReference => cleanType(safeGetType(f.getFieldOwner.getExpressionType)) + case s: IASTNamedTypeSpecifier => cleanType(ASTStringUtil.getReturnTypeString(s, null)) + case s: IASTCompositeTypeSpecifier => cleanType(ASTStringUtil.getReturnTypeString(s, null)) + case s: IASTEnumerationSpecifier => cleanType(ASTStringUtil.getReturnTypeString(s, null)) + case s: IASTElaboratedTypeSpecifier => cleanType(ASTStringUtil.getReturnTypeString(s, null)) + case l: IASTLiteralExpression => cleanType(safeGetType(l.getExpressionType)) + case e: IASTExpression => cleanType(safeGetNodeType(e)) + case _ => cleanType(getNodeSignature(node)) + } + } + + private def typeForCPPAstNamedTypeSpecifier(s: ICPPASTNamedTypeSpecifier): String = { + val tpe = safeGetBinding(s) match { + case Some(spec: ICPPSpecialization) => spec.toString + case Some(n: ICPPBinding) => n.getQualifiedName.mkString(".") + case Some(other: IBinding) => other.toString + case _ if s.getName != null => ASTStringUtil.getQualifiedName(s.getName) + case _ => s.getRawSignature } + cleanType(tpe) } private def notHandledText(node: IASTNode): String = s"""Node '${node.getClass.getSimpleName}' not handled yet! - | Code: '${node.getRawSignature}' + | Code: '${shortenCode(node.getRawSignature)}' | File: '$filename' | Line: ${line(node).getOrElse(-1)} | """.stripMargin @@ -258,6 +362,9 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As r } + protected def nullSafeAst(node: IASTInitializer): Ast = + Option(node).map(astForNode).getOrElse(Ast()) + protected def nullSafeAst(node: IASTExpression): Ast = Option(node).map(astForNode).getOrElse(Ast()) @@ -268,152 +375,17 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As Option(node).map(astsForStatement(_, argIndex)).getOrElse(Seq.empty) } - protected def dereferenceTypeFullName(fullName: String): String = - fullName.replace("*", "") - - protected def fixQualifiedName(name: String): String = - name.stripPrefix(Defines.qualifiedNameSeparator).replace(Defines.qualifiedNameSeparator, ".") - - protected def isQualifiedName(name: String): Boolean = - name.startsWith(Defines.qualifiedNameSeparator) - - protected def lastNameOfQualifiedName(name: String): String = { - val cleanedName = if (name.contains("<") && name.contains(">")) { - name.substring(0, name.indexOf("<")) - } else { - name - } - cleanedName.split(Defines.qualifiedNameSeparator).lastOption.getOrElse(cleanedName) - } - protected def functionTypeToSignature(typ: IFunctionType): String = { - val returnType = ASTTypeUtil.getType(typ.getReturnType) - val parameterTypes = typ.getParameterTypes.map(ASTTypeUtil.getType) - s"$returnType(${parameterTypes.mkString(",")})" + val returnType = cleanType(safeGetType(typ.getReturnType)) + val parameterTypes = typ.getParameterTypes.map(t => cleanType(safeGetType(t))) + StringUtils.normalizeSpace(s"$returnType(${parameterTypes.mkString(",")})") } - protected def fullName(node: IASTNode): String = { - node match { - case declarator: CPPASTFunctionDeclarator => - declarator.getName.resolveBinding() match { - case function: ICPPFunction => - val fullNameNoSig = function.getQualifiedName.mkString(".") - val fn = - if (function.isExternC) { - function.getName - } else { - s"$fullNameNoSig:${functionTypeToSignature(function.getType)}" - } - return fn - case field: ICPPField => - case _: IProblemBinding => - return "" - } - case declarator: CASTFunctionDeclarator => - val fn = declarator.getName.toString - return fn - case definition: ICPPASTFunctionDefinition => - return fullName(definition.getDeclarator) - case x => - } - - val qualifiedName: String = node match { - case d: CPPASTIdExpression => - safeGetEvaluation(d) match { - case Some(evalBinding: EvalBinding) => - evalBinding.getBinding match { - case f: CPPFunction if f.getDeclarations != null => - f.getDeclarations.headOption.map(n => s"${fullName(n)}").getOrElse(f.getName) - case f: CPPFunction if f.getDefinition != null => - s"${fullName(f.getDefinition)}" - case other => - other.getName - } - case _ => ASTStringUtil.getSimpleName(d.getName) - } - - case alias: ICPPASTNamespaceAlias => alias.getMappingName.toString - case namespace: ICPPASTNamespaceDefinition if ASTStringUtil.getSimpleName(namespace.getName).nonEmpty => - s"${fullName(namespace.getParent)}.${ASTStringUtil.getSimpleName(namespace.getName)}" - case namespace: ICPPASTNamespaceDefinition if ASTStringUtil.getSimpleName(namespace.getName).isEmpty => - s"${fullName(namespace.getParent)}.${uniqueName("namespace", "", "")._1}" - case compType: IASTCompositeTypeSpecifier if ASTStringUtil.getSimpleName(compType.getName).nonEmpty => - s"${fullName(compType.getParent)}.${ASTStringUtil.getSimpleName(compType.getName)}" - case compType: IASTCompositeTypeSpecifier if ASTStringUtil.getSimpleName(compType.getName).isEmpty => - val name = compType.getParent match { - case decl: IASTSimpleDeclaration => - decl.getDeclarators.headOption - .map(n => ASTStringUtil.getSimpleName(n.getName)) - .getOrElse(uniqueName("composite_type", "", "")._1) - case _ => uniqueName("composite_type", "", "")._1 - } - s"${fullName(compType.getParent)}.$name" - case enumSpecifier: IASTEnumerationSpecifier => - s"${fullName(enumSpecifier.getParent)}.${ASTStringUtil.getSimpleName(enumSpecifier.getName)}" - case f: ICPPASTLambdaExpression => - s"${fullName(f.getParent)}." - case f: IASTFunctionDefinition if f.getDeclarator != null => - s"${fullName(f.getParent)}.${ASTStringUtil.getQualifiedName(f.getDeclarator.getName)}" - case f: IASTFunctionDefinition => - s"${fullName(f.getParent)}.${shortName(f)}" - case e: IASTElaboratedTypeSpecifier => - s"${fullName(e.getParent)}.${ASTStringUtil.getSimpleName(e.getName)}" - case d: IASTIdExpression => ASTStringUtil.getSimpleName(d.getName) - case _: IASTTranslationUnit => "" - case u: IASTUnaryExpression => code(u.getOperand) - case x: ICPPASTQualifiedName => ASTStringUtil.getQualifiedName(x) - case other if other != null && other.getParent != null => fullName(other.getParent) - case other if other != null => notHandledYet(other); "" - case null => "" + private def pointersAsString(spec: IASTDeclSpecifier, parentDecl: IASTDeclarator): String = { + val tpe = typeFor(spec) match { + case Defines.Auto => typeFor(parentDecl) + case t => t } - fixQualifiedName(qualifiedName).stripPrefix(".") - } - - protected def shortName(node: IASTNode): String = { - val name = node match { - case d: IASTDeclarator if ASTStringUtil.getSimpleName(d.getName).isEmpty && d.getNestedDeclarator != null => - shortName(d.getNestedDeclarator) - case d: IASTDeclarator => ASTStringUtil.getSimpleName(d.getName) - case f: ICPPASTFunctionDefinition - if ASTStringUtil - .getSimpleName(f.getDeclarator.getName) - .isEmpty && f.getDeclarator.getNestedDeclarator != null => - shortName(f.getDeclarator.getNestedDeclarator) - case f: ICPPASTFunctionDefinition => lastNameOfQualifiedName(ASTStringUtil.getSimpleName(f.getDeclarator.getName)) - case f: IASTFunctionDefinition - if ASTStringUtil - .getSimpleName(f.getDeclarator.getName) - .isEmpty && f.getDeclarator.getNestedDeclarator != null => - shortName(f.getDeclarator.getNestedDeclarator) - case f: IASTFunctionDefinition => ASTStringUtil.getSimpleName(f.getDeclarator.getName) - case d: CPPASTIdExpression => - safeGetEvaluation(d) match { - case Some(evalBinding: EvalBinding) => - evalBinding.getBinding match { - case f: CPPFunction if f.getDeclarations != null => - f.getDeclarations.headOption.map(n => ASTStringUtil.getSimpleName(n.getName)).getOrElse(f.getName) - case f: CPPFunction if f.getDefinition != null => - ASTStringUtil.getSimpleName(f.getDefinition.getName) - case other => - other.getName - } - case _ => lastNameOfQualifiedName(ASTStringUtil.getSimpleName(d.getName)) - } - case d: IASTIdExpression => lastNameOfQualifiedName(ASTStringUtil.getSimpleName(d.getName)) - case u: IASTUnaryExpression => shortName(u.getOperand) - case c: IASTFunctionCallExpression => shortName(c.getFunctionNameExpression) - case s: IASTSimpleDeclSpecifier => s.getRawSignature - case e: IASTEnumerationSpecifier => ASTStringUtil.getSimpleName(e.getName) - case c: IASTCompositeTypeSpecifier => ASTStringUtil.getSimpleName(c.getName) - case e: IASTElaboratedTypeSpecifier => ASTStringUtil.getSimpleName(e.getName) - case s: IASTNamedTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) - case other => notHandledYet(other); "" - } - name - } - - private def pointersAsString(spec: IASTDeclSpecifier, parentDecl: IASTDeclarator, stripKeywords: Boolean): String = { - val tpe = typeFor(spec, stripKeywords) val pointers = parentDecl.getPointerOperators val arr = parentDecl match { case p: IASTArrayDeclarator => p.getArrayModifiers.toList.map(_.getRawSignature).mkString @@ -421,7 +393,13 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } if (pointers.isEmpty) { s"$tpe$arr" } else { - val refs = "*" * (pointers.length - pointers.count(_.isInstanceOf[ICPPASTReferenceOperator])) + val refs = pointers + .map { + case r: ICPPASTReferenceOperator if r.isRValueReference => "&&" + case _: ICPPASTReferenceOperator => "&" + case _: IASTPointer => "*" + } + .mkString("") s"$tpe$arr$refs".strip() } } @@ -430,7 +408,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As val allIncludes = iASTTranslationUnit.getIncludeDirectives.toList.filterNot(isIncludedNode) allIncludes.map { include => val name = include.getName.toString - val _dependencyNode = newDependencyNode(name, name, IncludeKeyword) + val _dependencyNode = newDependencyNode(name, name, "include") val importNode = newImportNode(code(include), name, name, include) diffGraph.addNode(_dependencyNode) diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) @@ -447,19 +425,19 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } private def astForDecltypeSpecifier(decl: ICPPASTDecltypeSpecifier): Ast = { - val op = ".typeOf" - val cpgUnary = callNode(decl, code(decl), op, op, DispatchTypes.STATIC_DISPATCH) + val op = Defines.OperatorTypeOf + val cpgUnary = callNode(decl, code(decl), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) val operand = nullSafeAst(decl.getDecltypeExpression) callAst(cpgUnary, List(operand)) } private def astForCASTDesignatedInitializer(d: ICASTDesignatedInitializer): Ast = { - val node = blockNode(d, Defines.empty, Defines.voidTypeName) + val node = blockNode(d, Defines.Empty, Defines.Void) scope.pushNewScope(node) val op = Operators.assignment val calls = withIndex(d.getDesignators) { (des, o) => val callNode_ = - callNode(d, code(d), op, op, DispatchTypes.STATIC_DISPATCH) + callNode(d, code(d), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) .argumentIndex(o) val left = astForNode(des) val right = astForNode(d.getOperand) @@ -470,12 +448,12 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } private def astForCPPASTDesignatedInitializer(d: ICPPASTDesignatedInitializer): Ast = { - val node = blockNode(d, Defines.empty, Defines.voidTypeName) + val node = blockNode(d, Defines.Empty, Defines.Void) scope.pushNewScope(node) val op = Operators.assignment val calls = withIndex(d.getDesignators) { (des, o) => val callNode_ = - callNode(d, code(d), op, op, DispatchTypes.STATIC_DISPATCH) + callNode(d, code(d), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) .argumentIndex(o) val left = astForNode(des) val right = astForNode(d.getOperand) @@ -486,16 +464,15 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } private def astForCPPASTConstructorInitializer(c: ICPPASTConstructorInitializer): Ast = { - val name = ".constructorInitializer" - val callNode_ = - callNode(c, code(c), name, name, DispatchTypes.STATIC_DISPATCH) - val args = c.getArguments.toList.map(a => astForNode(a)) + val name = Defines.OperatorConstructorInitializer + val callNode_ = callNode(c, code(c), name, name, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) + val args = c.getArguments.toList.map(a => astForNode(a)) callAst(callNode_, args) } private def astForCASTArrayRangeDesignator(des: CASTArrayRangeDesignator): Ast = { val op = Operators.arrayInitializer - val callNode_ = callNode(des, code(des), op, op, DispatchTypes.STATIC_DISPATCH) + val callNode_ = callNode(des, code(des), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) val floorAst = nullSafeAst(des.getRangeFloor) val ceilingAst = nullSafeAst(des.getRangeCeiling) callAst(callNode_, List(floorAst, ceilingAst)) @@ -503,7 +480,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As private def astForCPPASTArrayRangeDesignator(des: CPPASTArrayRangeDesignator): Ast = { val op = Operators.arrayInitializer - val callNode_ = callNode(des, code(des), op, op, DispatchTypes.STATIC_DISPATCH) + val callNode_ = callNode(des, code(des), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) val floorAst = nullSafeAst(des.getRangeFloor) val ceilingAst = nullSafeAst(des.getRangeCeiling) callAst(callNode_, List(floorAst, ceilingAst)) @@ -517,6 +494,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As case l: IASTInitializerList => astForInitializerList(l) case c: ICPPASTConstructorInitializer => astForCPPASTConstructorInitializer(c) case d: ICASTDesignatedInitializer => astForCASTDesignatedInitializer(d) + case d: IASTEqualsInitializer => astForNode(d.getInitializerClause) case d: ICPPASTDesignatedInitializer => astForCPPASTDesignatedInitializer(d) case d: CASTArrayRangeDesignator => astForCASTArrayRangeDesignator(d) case d: CPPASTArrayRangeDesignator => astForCPPASTArrayRangeDesignator(d) @@ -530,49 +508,49 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } } - protected def typeForDeclSpecifier(spec: IASTNode, stripKeywords: Boolean = true, index: Int = 0): String = { + protected def typeForDeclSpecifier(spec: IASTNode, index: Int = 0): String = { val tpe = spec match { case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator - pointersAsString(s, parentDecl, stripKeywords) + pointersAsString(s, parentDecl) case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTFunctionDefinition] => val parentDecl = s.getParent.asInstanceOf[IASTFunctionDefinition].getDeclarator ASTStringUtil.getReturnTypeString(s, parentDecl) case s: IASTSimpleDeclaration if s.getParent.isInstanceOf[ICASTKnRFunctionDeclarator] => val decl = s.getDeclarators.toList(index) - pointersAsString(s.getDeclSpecifier, decl, stripKeywords) + pointersAsString(s.getDeclSpecifier, decl) case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => val parentDecl = s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) + pointersAsString(s, parentDecl) case s: IASTSimpleDeclSpecifier => ASTStringUtil.getReturnTypeString(s, null) case s: IASTNamedTypeSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator - pointersAsString(s, parentDecl, stripKeywords) + pointersAsString(s, parentDecl) case s: IASTNamedTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => val parentDecl = s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) + pointersAsString(s, parentDecl) case s: IASTNamedTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) case s: IASTCompositeTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => val parentDecl = s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) + pointersAsString(s, parentDecl) case s: IASTCompositeTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) case s: IASTEnumerationSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => val parentDecl = s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) + pointersAsString(s, parentDecl) case s: IASTEnumerationSpecifier => ASTStringUtil.getSimpleName(s.getName) case s: IASTElaboratedTypeSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator - pointersAsString(s, parentDecl, stripKeywords) + pointersAsString(s, parentDecl) case s: IASTElaboratedTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => val parentDecl = s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) + pointersAsString(s, parentDecl) case s: IASTElaboratedTypeSpecifier => ASTStringUtil.getSignatureString(s, null) // TODO: handle other types of IASTDeclSpecifier - case _ => Defines.anyTypeName + case _ => Defines.Any } - if (tpe.isEmpty) Defines.anyTypeName else tpe + if (tpe.isEmpty) Defines.Any else tpe } // We use our own call ast creation function since the version in x2cpg treats diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForExpressionsCreator.scala index 3d999c23d4cd..c3d31be2bf2e 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForExpressionsCreator.scala @@ -1,31 +1,25 @@ package io.joern.c2cpg.astcreation -import io.shiftleft.codepropertygraph.generated.nodes.{NewCall, NewIdentifier, NewMethodRef} -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.joern.x2cpg.{Ast, ValidationMode} +import io.joern.x2cpg.Ast +import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.Defines as X2CpgDefines +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators} +import org.apache.commons.lang3.StringUtils import org.eclipse.cdt.core.dom.ast import org.eclipse.cdt.core.dom.ast.* import org.eclipse.cdt.core.dom.ast.cpp.* import org.eclipse.cdt.core.dom.ast.gnu.IGNUASTCompoundStatementExpression -import org.eclipse.cdt.core.model.IMethod -import org.eclipse.cdt.internal.core.dom.parser.c.{ - CASTFieldReference, - CASTFunctionCallExpression, - CASTIdExpression, - CBasicType, - CFunctionType, - CPointerType -} -import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.{EvalBinding, EvalFunctionCall} -import org.eclipse.cdt.internal.core.dom.parser.cpp.{ - CPPASTIdExpression, - CPPASTQualifiedName, - CPPClosureType, - CPPField, - CPPFunction, - CPPFunctionType -} +import org.eclipse.cdt.internal.core.dom.parser.c.CASTFunctionCallExpression +import org.eclipse.cdt.internal.core.dom.parser.c.CASTIdExpression +import org.eclipse.cdt.internal.core.dom.parser.c.CFunctionType +import org.eclipse.cdt.internal.core.dom.parser.c.CPointerType +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTIdExpression +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTQualifiedName +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPClosureType +import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.EvalFunctionCall +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTFoldExpression + +import scala.util.Try trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => @@ -62,67 +56,55 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case IASTBinaryExpression.op_notequals => Operators.notEquals case IASTBinaryExpression.op_pmdot => Operators.indirectFieldAccess case IASTBinaryExpression.op_pmarrow => Operators.indirectFieldAccess - case IASTBinaryExpression.op_max => ".max" - case IASTBinaryExpression.op_min => ".min" - case IASTBinaryExpression.op_ellipses => ".op_ellipses" - case _ => ".unknown" + case IASTBinaryExpression.op_max => Defines.OperatorMax + case IASTBinaryExpression.op_min => Defines.OperatorMin + case IASTBinaryExpression.op_ellipses => Defines.OperatorEllipses + case _ => Defines.OperatorUnknown } - val callNode_ = callNode(bin, code(bin), op, op, DispatchTypes.STATIC_DISPATCH) + val callNode_ = callNode(bin, code(bin), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) val left = nullSafeAst(bin.getOperand1) val right = nullSafeAst(bin.getOperand2) callAst(callNode_, List(left, right)) } private def astForExpressionList(exprList: IASTExpressionList): Ast = { - val name = ".expressionList" + val name = Defines.OperatorExpressionList val callNode_ = - callNode(exprList, code(exprList), name, name, DispatchTypes.STATIC_DISPATCH) + callNode(exprList, code(exprList), name, name, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) val childAsts = exprList.getExpressions.map(nullSafeAst) callAst(callNode_, childAsts.toIndexedSeq) } private def astForCppCallExpression(call: ICPPASTFunctionCallExpression): Ast = { val functionNameExpr = call.getFunctionNameExpression - val typ = functionNameExpr.getExpressionType - typ match { - case pointerType: IPointerType => - createPointerCallAst(call, cleanType(ASTTypeUtil.getType(call.getExpressionType))) - case functionType: ICPPFunctionType => + Try(functionNameExpr.getExpressionType).toOption match { + case Some(_: IPointerType) => createPointerCallAst(call, cleanType(safeGetType(call.getExpressionType))) + case Some(functionType: ICPPFunctionType) => functionNameExpr match { - case idExpr: CPPASTIdExpression => - val function = idExpr.getName.getBinding.asInstanceOf[ICPPFunction] - val name = idExpr.getName.getLastName.toString - val signature = - if (function.isExternC) { - "" - } else { - functionTypeToSignature(functionType) - } - - val fullName = - if (function.isExternC) { - name - } else { - val fullNameNoSig = function.getQualifiedName.mkString(".") - s"$fullNameNoSig:$signature" - } - - val dispatchType = DispatchTypes.STATIC_DISPATCH - + case idExpr: CPPASTIdExpression if safeGetBinding(idExpr).exists(_.isInstanceOf[ICPPFunction]) => + val function = idExpr.getName.getBinding.asInstanceOf[ICPPFunction] + val name = idExpr.getName.getLastName.toString + val signature = if function.isExternC then "" else functionTypeToSignature(functionType) + val fullName = if (function.isExternC) { + StringUtils.normalizeSpace(name) + } else { + val fullNameNoSig = StringUtils.normalizeSpace(function.getQualifiedName.mkString(".")) + s"$fullNameNoSig:$signature" + } val callCpgNode = callNode( call, code(call), name, fullName, - dispatchType, + DispatchTypes.STATIC_DISPATCH, Some(signature), - Some(cleanType(ASTTypeUtil.getType(call.getExpressionType))) + Some(registerType(cleanType(safeGetType(call.getExpressionType)))) ) val args = call.getArguments.toList.map(a => astForNode(a)) - createCallAst(callCpgNode, args) - case fieldRefExpr: ICPPASTFieldReference => + case fieldRefExpr: ICPPASTFieldReference + if safeGetBinding(fieldRefExpr.getFieldName).exists(_.isInstanceOf[ICPPMethod]) => val instanceAst = astForExpression(fieldRefExpr.getFieldOwner) val args = call.getArguments.toList.map(a => astForNode(a)) @@ -130,11 +112,10 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val name = fieldRefExpr.getFieldName.toString val signature = functionTypeToSignature(functionType) - val classFullName = cleanType(ASTTypeUtil.getType(fieldRefExpr.getFieldOwnerType)) + val classFullName = cleanType(safeGetType(fieldRefExpr.getFieldOwnerType)) val fullName = s"$classFullName.$name:$signature" - fieldRefExpr.getFieldName.resolveBinding() - val method = fieldRefExpr.getFieldName.getBinding().asInstanceOf[ICPPMethod] + val method = fieldRefExpr.getFieldName.getBinding.asInstanceOf[ICPPMethod] val (dispatchType, receiver) = if (method.isVirtual || method.isPureVirtual) { (DispatchTypes.DYNAMIC_DISPATCH, Some(instanceAst)) @@ -148,48 +129,45 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { fullName, dispatchType, Some(signature), - Some(cleanType(ASTTypeUtil.getType(call.getExpressionType))) + Some(registerType(cleanType(safeGetType(call.getExpressionType)))) ) - createCallAst(callCpgNode, args, base = Some(instanceAst), receiver) + case _ => + astForCppCallExpressionUntyped(call) } - case classType: ICPPClassType => + case Some(classType: ICPPClassType) if safeGetEvaluation(call).exists(_.isInstanceOf[EvalFunctionCall]) => val evaluation = call.getEvaluation.asInstanceOf[EvalFunctionCall] - val functionType = evaluation.getOverload.getType - val signature = functionTypeToSignature(functionType) - val name = "()" - + val functionType = Try(evaluation.getOverload.getType).toOption + val signature = functionType.map(functionTypeToSignature).getOrElse(X2CpgDefines.UnresolvedSignature) + val name = Defines.OperatorCall classType match { - case closureType: CPPClosureType => - val fullName = s"$name:$signature" - val dispatchType = DispatchTypes.DYNAMIC_DISPATCH - + case _: CPPClosureType => + val fullName = s"$name:$signature" val callCpgNode = callNode( call, code(call), name, fullName, - dispatchType, + DispatchTypes.DYNAMIC_DISPATCH, Some(signature), - Some(cleanType(ASTTypeUtil.getType(call.getExpressionType))) + Some(registerType(cleanType(safeGetType(call.getExpressionType)))) ) - val receiverAst = astForExpression(functionNameExpr) val args = call.getArguments.toList.map(a => astForNode(a)) - createCallAst(callCpgNode, args, receiver = Some(receiverAst)) case _ => - val classFullName = cleanType(ASTTypeUtil.getType(classType)) + val classFullName = cleanType(safeGetType(classType)) val fullName = s"$classFullName.$name:$signature" - - val method = evaluation.getOverload.asInstanceOf[ICPPMethod] - val dispatchType = - if (method.isVirtual || method.isPureVirtual) { - DispatchTypes.DYNAMIC_DISPATCH - } else { + val dispatchType = evaluation.getOverload match { + case method: ICPPMethod => + if (method.isVirtual || method.isPureVirtual) { + DispatchTypes.DYNAMIC_DISPATCH + } else { + DispatchTypes.STATIC_DISPATCH + } + case _ => DispatchTypes.STATIC_DISPATCH - } - + } val callCpgNode = callNode( call, code(call), @@ -197,33 +175,24 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { fullName, dispatchType, Some(signature), - Some(cleanType(ASTTypeUtil.getType(call.getExpressionType))) + Some(registerType(cleanType(safeGetType(call.getExpressionType)))) ) - val instanceAst = astForExpression(functionNameExpr) val args = call.getArguments.toList.map(a => astForNode(a)) - createCallAst(callCpgNode, args, base = Some(instanceAst), receiver = Some(instanceAst)) } - case _: IProblemType => - astForCppCallExpressionUntyped(call) - case _: IProblemBinding => - astForCppCallExpressionUntyped(call) + case _ => astForCppCallExpressionUntyped(call) } } private def astForCppCallExpressionUntyped(call: ICPPASTFunctionCallExpression): Ast = { - val functionNameExpr = call.getFunctionNameExpression - - functionNameExpr match { + call.getFunctionNameExpression match { case fieldRefExpr: ICPPASTFieldReference => val instanceAst = astForExpression(fieldRefExpr.getFieldOwner) val args = call.getArguments.toList.map(a => astForNode(a)) - - val name = fieldRefExpr.getFieldName.toString - val signature = X2CpgDefines.UnresolvedSignature - val fullName = s"${X2CpgDefines.UnresolvedNamespace}.$name:$signature(${args.size})" - + val name = StringUtils.normalizeSpace(fieldRefExpr.getFieldName.toString) + val signature = X2CpgDefines.UnresolvedSignature + val fullName = s"${X2CpgDefines.UnresolvedNamespace}.$name:$signature(${args.size})" val callCpgNode = callNode( call, code(call), @@ -233,15 +202,12 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { Some(signature), Some(X2CpgDefines.Any) ) - createCallAst(callCpgNode, args, base = Some(instanceAst), receiver = Some(instanceAst)) case idExpr: CPPASTIdExpression => - val args = call.getArguments.toList.map(a => astForNode(a)) - - val name = idExpr.getName.getLastName.toString + val args = call.getArguments.toList.map(a => astForNode(a)) + val name = StringUtils.normalizeSpace(idExpr.getName.getLastName.toString) val signature = X2CpgDefines.UnresolvedSignature val fullName = s"${X2CpgDefines.UnresolvedNamespace}.$name:$signature(${args.size})" - val callCpgNode = callNode( call, code(call), @@ -251,17 +217,14 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { Some(signature), Some(X2CpgDefines.Any) ) - createCallAst(callCpgNode, args) - case other => - // This could either be a pointer or an operator() call we dont know at this point - // but since it is CPP we opt for the later. - val args = call.getArguments.toList.map(a => astForNode(a)) - - val name = "()" + case otherExpr => + // This could either be a pointer or an operator() call we do not know at this point + // but since it is CPP we opt for the latter. + val args = call.getArguments.toList.map(a => astForNode(a)) + val name = Defines.OperatorCall val signature = X2CpgDefines.UnresolvedSignature val fullName = s"${X2CpgDefines.UnresolvedNamespace}.$name:$signature(${args.size})" - val callCpgNode = callNode( call, code(call), @@ -271,24 +234,22 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { Some(signature), Some(X2CpgDefines.Any) ) - - val instanceAst = astForExpression(functionNameExpr) + val instanceAst = astForExpression(otherExpr) createCallAst(callCpgNode, args, base = Some(instanceAst), receiver = Some(instanceAst)) } } private def astForCCallExpression(call: CASTFunctionCallExpression): Ast = { val functionNameExpr = call.getFunctionNameExpression - val typ = functionNameExpr.getExpressionType - typ match { - case pointerType: CPointerType => - createPointerCallAst(call, cleanType(ASTTypeUtil.getType(call.getExpressionType))) - case functionType: CFunctionType => + Try(functionNameExpr.getExpressionType).toOption match { + case Some(_: CPointerType) => + createPointerCallAst(call, cleanType(safeGetType(call.getExpressionType))) + case Some(_: CFunctionType) => functionNameExpr match { case idExpr: CASTIdExpression => - createCFunctionCallAst(call, idExpr, cleanType(ASTTypeUtil.getType(call.getExpressionType))) + createCFunctionCallAst(call, idExpr, cleanType(safeGetType(call.getExpressionType))) case _ => - createPointerCallAst(call, cleanType(ASTTypeUtil.getType(call.getExpressionType))) + createPointerCallAst(call, cleanType(safeGetType(call.getExpressionType))) } case _ => astForCCallExpressionUntyped(call) @@ -300,50 +261,46 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { idExpr: CASTIdExpression, callTypeFullName: String ): Ast = { - val name = idExpr.getName.getLastName.toString - val signature = "" - + val name = idExpr.getName.getLastName.toString + val signature = "" val dispatchType = DispatchTypes.STATIC_DISPATCH - - val callCpgNode = callNode(call, code(call), name, name, dispatchType, Some(signature), Some(callTypeFullName)) - val args = call.getArguments.toList.map(a => astForNode(a)) - + val callCpgNode = + callNode(call, code(call), name, name, dispatchType, Some(signature), Some(registerType(callTypeFullName))) + val args = call.getArguments.toList.map(a => astForNode(a)) createCallAst(callCpgNode, args) } private def createPointerCallAst(call: IASTFunctionCallExpression, callTypeFullName: String): Ast = { val functionNameExpr = call.getFunctionNameExpression - val name = Defines.operatorPointerCall + val name = Defines.OperatorPointerCall val signature = "" - + val dispatchType = DispatchTypes.DYNAMIC_DISPATCH val callCpgNode = - callNode(call, code(call), name, name, DispatchTypes.DYNAMIC_DISPATCH, Some(signature), Some(callTypeFullName)) - + callNode(call, code(call), name, name, dispatchType, Some(signature), Some(registerType(callTypeFullName))) val args = call.getArguments.toList.map(a => astForNode(a)) val receiverAst = astForExpression(functionNameExpr) createCallAst(callCpgNode, args, receiver = Some(receiverAst)) } private def astForCCallExpressionUntyped(call: CASTFunctionCallExpression): Ast = { - val functionNameExpr = call.getFunctionNameExpression - - functionNameExpr match { - case idExpr: CASTIdExpression => - createCFunctionCallAst(call, idExpr, X2CpgDefines.Any) - case _ => - createPointerCallAst(call, X2CpgDefines.Any) + call.getFunctionNameExpression match { + case idExpr: CASTIdExpression => createCFunctionCallAst(call, idExpr, X2CpgDefines.Any) + case _ => createPointerCallAst(call, X2CpgDefines.Any) } } private def astForCallExpression(call: IASTFunctionCallExpression): Ast = { call match { - case cppCall: ICPPASTFunctionCallExpression => - astForCppCallExpression(cppCall) - case cCall: CASTFunctionCallExpression => - astForCCallExpression(cCall) + case cppCall: ICPPASTFunctionCallExpression => astForCppCallExpression(cppCall) + case cCall: CASTFunctionCallExpression => astForCCallExpression(cCall) } } + private def astForThrowExpression(expression: IASTUnaryExpression): Ast = { + val operand = nullSafeAst(expression.getOperand) + Ast(controlStructureNode(expression, ControlStructureTypes.THROW, code(expression))).withChild(operand) + } + private def astForUnaryExpression(unary: IASTUnaryExpression): Ast = { val operatorMethod = unary.getOperator match { case IASTUnaryExpression.op_prefixIncr => Operators.preIncrement @@ -357,10 +314,9 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case IASTUnaryExpression.op_sizeof => Operators.sizeOf case IASTUnaryExpression.op_postFixIncr => Operators.postIncrement case IASTUnaryExpression.op_postFixDecr => Operators.postDecrement - case IASTUnaryExpression.op_throw => ".throw" - case IASTUnaryExpression.op_typeid => ".typeOf" - case IASTUnaryExpression.op_bracketedPrimary => ".bracketedPrimary" - case _ => ".unknown" + case IASTUnaryExpression.op_typeid => Defines.OperatorTypeOf + case IASTUnaryExpression.op_bracketedPrimary => Defines.OperatorBracketedPrimary + case _ => Defines.OperatorUnknown } if ( @@ -369,8 +325,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { ) { nullSafeAst(unary.getOperand) } else { - val cpgUnary = - callNode(unary, code(unary), operatorMethod, operatorMethod, DispatchTypes.STATIC_DISPATCH) + val cpgUnary = callNode( + unary, + code(unary), + operatorMethod, + operatorMethod, + DispatchTypes.STATIC_DISPATCH, + None, + Some(X2CpgDefines.Any) + ) val operand = nullSafeAst(unary.getOperand) callAst(cpgUnary, List(operand)) } @@ -385,7 +348,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { op == IASTTypeIdExpression.op_alignof || op == IASTTypeIdExpression.op_typeof => val call = - callNode(typeId, code(typeId), Operators.sizeOf, Operators.sizeOf, DispatchTypes.STATIC_DISPATCH) + callNode( + typeId, + code(typeId), + Operators.sizeOf, + Operators.sizeOf, + DispatchTypes.STATIC_DISPATCH, + None, + Some(X2CpgDefines.Any) + ) val arg = astForNode(typeId.getTypeId.getDeclSpecifier) callAst(call, List(arg)) case _ => notHandledYet(typeId) @@ -394,7 +365,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { private def astForConditionalExpression(expr: IASTConditionalExpression): Ast = { val name = Operators.conditional - val call = callNode(expr, code(expr), name, name, DispatchTypes.STATIC_DISPATCH) + val call = callNode(expr, code(expr), name, name, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) val condAst = nullSafeAst(expr.getLogicalConditionExpression) val posAst = nullSafeAst(expr.getPositiveResultExpression) @@ -407,7 +378,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { private def astForArrayIndexExpression(arrayIndexExpression: IASTArraySubscriptExpression): Ast = { val name = Operators.indirectIndexAccess val cpgArrayIndexing = - callNode(arrayIndexExpression, code(arrayIndexExpression), name, name, DispatchTypes.STATIC_DISPATCH) + callNode( + arrayIndexExpression, + code(arrayIndexExpression), + name, + name, + DispatchTypes.STATIC_DISPATCH, + None, + Some(X2CpgDefines.Any) + ) val expr = astForExpression(arrayIndexExpression.getArrayExpression) val arg = astForNode(arrayIndexExpression.getArgument) @@ -416,7 +395,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { private def astForCastExpression(castExpression: IASTCastExpression): Ast = { val cpgCastExpression = - callNode(castExpression, code(castExpression), Operators.cast, Operators.cast, DispatchTypes.STATIC_DISPATCH) + callNode( + castExpression, + code(castExpression), + Operators.cast, + Operators.cast, + DispatchTypes.STATIC_DISPATCH, + None, + Some(X2CpgDefines.Any) + ) val expr = astForExpression(castExpression.getOperand) val argNode = castExpression.getTypeId @@ -438,9 +425,16 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } private def astForNewExpression(newExpression: ICPPASTNewExpression): Ast = { - val name = ".new" - val cpgNewExpression = - callNode(newExpression, code(newExpression), name, name, DispatchTypes.STATIC_DISPATCH) + val name = Defines.OperatorNew + val cpgNewExpression = callNode( + newExpression, + code(newExpression), + name, + name, + DispatchTypes.STATIC_DISPATCH, + None, + Some(X2CpgDefines.Any) + ) val typeId = newExpression.getTypeId if (newExpression.isArrayAllocation) { @@ -457,7 +451,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { private def astForDeleteExpression(delExpression: ICPPASTDeleteExpression): Ast = { val name = Operators.delete val cpgDeleteNode = - callNode(delExpression, code(delExpression), name, name, DispatchTypes.STATIC_DISPATCH) + callNode( + delExpression, + code(delExpression), + name, + name, + DispatchTypes.STATIC_DISPATCH, + None, + Some(X2CpgDefines.Any) + ) val arg = astForExpression(delExpression.getOperand) callAst(cpgDeleteNode, List(arg)) } @@ -465,7 +467,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { private def astForTypeIdInitExpression(typeIdInit: IASTTypeIdInitializerExpression): Ast = { val name = Operators.cast val cpgCastExpression = - callNode(typeIdInit, code(typeIdInit), name, name, DispatchTypes.STATIC_DISPATCH) + callNode(typeIdInit, code(typeIdInit), name, name, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) val typeAst = unknownNode(typeIdInit.getTypeId, code(typeIdInit.getTypeId)) val expr = astForNode(typeIdInit.getInitializer) @@ -473,10 +475,47 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } private def astForConstructorExpression(c: ICPPASTSimpleTypeConstructorExpression): Ast = { - val name = c.getDeclSpecifier.toString - val callNode_ = callNode(c, code(c), name, name, DispatchTypes.STATIC_DISPATCH) - val arg = astForNode(c.getInitializer) - callAst(callNode_, List(arg)) + val name = c.getDeclSpecifier.toString + c.getInitializer match { + case l: ICPPASTInitializerList if l.getClauses.forall(_.isInstanceOf[ICPPASTDesignatedInitializer]) => + val node = blockNode(c, Defines.Empty, Defines.Void) + scope.pushNewScope(node) + + val inits = l.getClauses.collect { case i: ICPPASTDesignatedInitializer => i }.toSeq + val calls = inits.flatMap { init => + val designatorIds = init.getDesignators.collect { case d: ICPPASTFieldDesignator => + val name = code(d.getName) + fieldIdentifierNode(d, name, name) + } + designatorIds.map { memberId => + val rhsAst = astForNode(init.getOperand) + val specifierId = + identifierNode(c.getDeclSpecifier, name, name, registerType(cleanType(typeFor(c.getDeclSpecifier)))) + val op = Operators.fieldAccess + val accessCode = s"$name.${memberId.code}" + val ma = callNode(init, accessCode, op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) + val maAst = callAst(ma, List(Ast(specifierId), Ast(memberId))) + val assignmentCallNode = + callNode( + c, + s"$accessCode = ${code(init.getOperand)}", + Operators.assignment, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH, + None, + Some(X2CpgDefines.Any) + ) + callAst(assignmentCallNode, List(maAst, rhsAst)) + } + } + + scope.popScope() + blockAst(node, calls.toList) + case other => + val callNode_ = callNode(c, code(c), name, name, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) + val arg = astForNode(other) + callAst(callNode_, List(arg)) + } } private def astForCompoundStatementExpression(compoundExpression: IGNUASTCompoundStatementExpression): Ast = @@ -485,16 +524,37 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { private def astForPackExpansionExpression(packExpansionExpression: ICPPASTPackExpansionExpression): Ast = astForExpression(packExpansionExpression.getPattern) + private def astForFoldExpression(foldExpression: CPPASTFoldExpression): Ast = { + def valueFromField[T](obj: Any, fieldName: String): Option[T] = { + // we need this hack because fields are all private at CPPASTExpression + Try { + val field = obj.getClass.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(obj).asInstanceOf[T] + }.toOption + } + + val op = ".fold" + val tpe = typeFor(foldExpression) + val callNode_ = + callNode(foldExpression, code(foldExpression), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(tpe)) + + val left = valueFromField[ICPPASTExpression](foldExpression, "fLhs").map(nullSafeAst).getOrElse(Ast()) + val right = valueFromField[ICPPASTExpression](foldExpression, "fRhs").map(nullSafeAst).getOrElse(Ast()) + callAst(callNode_, List(left, right)) + } + protected def astForExpression(expression: IASTExpression): Ast = { val r = expression match { - case lit: IASTLiteralExpression => astForLiteral(lit) - case un: IASTUnaryExpression => astForUnaryExpression(un) - case bin: IASTBinaryExpression => astForBinaryExpression(bin) - case exprList: IASTExpressionList => astForExpressionList(exprList) - case idExpr: IASTIdExpression => astForIdExpression(idExpr) - case call: IASTFunctionCallExpression => astForCallExpression(call) - case typeId: IASTTypeIdExpression => astForTypeIdExpression(typeId) - case fieldRef: IASTFieldReference => astForFieldReference(fieldRef) + case lit: IASTLiteralExpression => astForLiteral(lit) + case un: IASTUnaryExpression if un.getOperator == IASTUnaryExpression.op_throw => astForThrowExpression(un) + case un: IASTUnaryExpression => astForUnaryExpression(un) + case bin: IASTBinaryExpression => astForBinaryExpression(bin) + case exprList: IASTExpressionList => astForExpressionList(exprList) + case idExpr: IASTIdExpression => astForIdExpression(idExpr) + case call: IASTFunctionCallExpression => astForCallExpression(call) + case typeId: IASTTypeIdExpression => astForTypeIdExpression(typeId) + case fieldRef: IASTFieldReference => astForFieldReference(fieldRef) case expr: IASTConditionalExpression => astForConditionalExpression(expr) case arr: IASTArraySubscriptExpression => astForArrayIndexExpression(arr) case castExpression: IASTCastExpression => astForCastExpression(castExpression) @@ -502,25 +562,27 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case delExpression: ICPPASTDeleteExpression => astForDeleteExpression(delExpression) case typeIdInit: IASTTypeIdInitializerExpression => astForTypeIdInitExpression(typeIdInit) case c: ICPPASTSimpleTypeConstructorExpression => astForConstructorExpression(c) - case lambdaExpression: ICPPASTLambdaExpression => astForMethodRefForLambda(lambdaExpression) + case lambdaExpression: ICPPASTLambdaExpression => astForLambdaExpression(lambdaExpression) case cExpr: IGNUASTCompoundStatementExpression => astForCompoundStatementExpression(cExpr) case pExpr: ICPPASTPackExpansionExpression => astForPackExpansionExpression(pExpr) + case foldExpression: CPPASTFoldExpression => astForFoldExpression(foldExpression) case _ => notHandledYet(expression) } asChildOfMacroCall(expression, r) } private def astForIdExpression(idExpression: IASTIdExpression): Ast = idExpression.getName match { - case name: CPPASTQualifiedName => astForQualifiedName(name) - case _ => astForIdentifier(idExpression) + case name: CPPASTQualifiedName => astForQualifiedName(name) + case name: ICPPASTName if name.getRawSignature == "constinit" => Ast() + case _ => astForIdentifier(idExpression) } protected def astForStaticAssert(a: ICPPASTStaticAssertDeclaration): Ast = { - val name = "static_assert" - val call = callNode(a, code(a), name, name, DispatchTypes.STATIC_DISPATCH) - val cond = nullSafeAst(a.getCondition) - val messg = nullSafeAst(a.getMessage) - callAst(call, List(cond, messg)) + val name = "static_assert" + val call = callNode(a, code(a), name, name, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) + val cond = nullSafeAst(a.getCondition) + val message = nullSafeAst(a.getMessage) + callAst(call, List(cond, message)) } } diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForFunctionsCreator.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForFunctionsCreator.scala index 1e7ac31e24f6..5331881c42d3 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForFunctionsCreator.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForFunctionsCreator.scala @@ -1,29 +1,42 @@ package io.joern.c2cpg.astcreation -import io.joern.x2cpg.{Ast, ValidationMode} +import io.joern.x2cpg.Ast +import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.datastructures.Stack.* import io.joern.x2cpg.utils.NodeBuilders.newModifierNode +import io.shiftleft.codepropertygraph.generated.EvaluationStrategies +import io.shiftleft.codepropertygraph.generated.ModifierTypes import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, ModifierTypes} import org.apache.commons.lang3.StringUtils import org.eclipse.cdt.core.dom.ast.* -import org.eclipse.cdt.core.dom.ast.cpp.{ICPPASTLambdaExpression, ICPPFunction} +import org.eclipse.cdt.core.dom.ast.cpp.ICPPASTFunctionDefinition +import org.eclipse.cdt.core.dom.ast.cpp.ICPPASTLambdaExpression +import org.eclipse.cdt.core.dom.ast.cpp.ICPPASTQualifiedName +import org.eclipse.cdt.core.dom.ast.cpp.ICPPBinding import org.eclipse.cdt.core.dom.ast.gnu.c.ICASTKnRFunctionDeclarator -import org.eclipse.cdt.internal.core.dom.parser.c.{CASTFunctionDeclarator, CASTParameterDeclaration, CTypedef} -import org.eclipse.cdt.internal.core.dom.parser.cpp.{ - CPPASTFunctionDeclarator, - CPPASTFunctionDefinition, - CPPASTParameterDeclaration, - CPPFunction -} +import org.eclipse.cdt.internal.core.dom.parser.c.CASTFunctionDeclarator +import org.eclipse.cdt.internal.core.dom.parser.c.CASTParameterDeclaration +import org.eclipse.cdt.internal.core.dom.parser.c.CVariable +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTFunctionDeclarator +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTFunctionDefinition +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTParameterDeclaration +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTQualifiedName +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPClassType +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPEnumeration +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPStructuredBindingComposite +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPVariable import org.eclipse.cdt.internal.core.model.ASTStringUtil import scala.annotation.tailrec -import scala.collection.mutable +import scala.util.Try trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - private val seenFunctionFullnames = mutable.HashSet.empty[String] + protected def methodDeclarationParentInfo(): (String, String) = { + methodAstParentStack.collectFirst { case t: NewTypeDecl => (t.label, t.fullName) }.getOrElse { + (methodAstParentStack.head.label, methodAstParentStack.head.properties("FULL_NAME").toString) + } + } private def createFunctionTypeAndTypeDecl( node: IASTNode, @@ -57,7 +70,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th Ast(functionBinding).withBindsEdge(parentNode, functionBinding).withRefEdge(functionBinding, method) } - private def parameters(functionNode: IASTNode): Seq[IASTNode] = functionNode match { + final protected def parameters(functionNode: IASTNode): Seq[IASTNode] = functionNode match { case arr: IASTArrayDeclarator => parameters(arr.getNestedDeclarator) case decl: CPPASTFunctionDeclarator => decl.getParameters.toIndexedSeq ++ parameters(decl.getNestedDeclarator) case decl: CASTFunctionDeclarator => decl.getParameters.toIndexedSeq ++ parameters(decl.getNestedDeclarator) @@ -70,24 +83,15 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th } @tailrec - private def isVariadic(functionNode: IASTNode): Boolean = functionNode match { - case decl: CPPASTFunctionDeclarator => decl.takesVarArgs() - case decl: CASTFunctionDeclarator => decl.takesVarArgs() - case defn: IASTFunctionDefinition => isVariadic(defn.getDeclarator) - case lambdaExpression: ICPPASTLambdaExpression => isVariadic(lambdaExpression.getDeclarator) - case _ => false + final protected def isVariadic(functionNode: IASTNode): Boolean = functionNode match { + case decl: CPPASTFunctionDeclarator => decl.takesVarArgs() + case decl: CASTFunctionDeclarator => decl.takesVarArgs() + case functionDefinition: IASTFunctionDefinition => isVariadic(functionDefinition.getDeclarator) + case lambdaExpression: ICPPASTLambdaExpression => isVariadic(lambdaExpression.getDeclarator) + case _ => false } - private def parameterListSignature(func: IASTNode): String = { - val variadic = if (isVariadic(func)) "..." else "" - val elements = parameters(func).map { - case p: IASTParameterDeclaration => typeForDeclSpecifier(p.getDeclSpecifier) - case other => typeForDeclSpecifier(other) - } - s"(${elements.mkString(",")}$variadic)" - } - - private def setVariadic(parameterNodes: Seq[NewMethodParameterIn], func: IASTNode): Unit = { + protected def setVariadic(parameterNodes: Seq[NewMethodParameterIn], func: IASTNode): Unit = { parameterNodes.lastOption.foreach { case p: NewMethodParameterIn if isVariadic(func) => p.isVariadic = true @@ -96,80 +100,62 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th } } - protected def astForMethodRefForLambda(lambdaExpression: ICPPASTLambdaExpression): Ast = { - val filename = fileName(lambdaExpression) - - val returnType = lambdaExpression.getDeclarator match { - case declarator: IASTDeclarator => - declarator.getTrailingReturnType match { - case id: IASTTypeId => typeForDeclSpecifier(id.getDeclSpecifier) - case null => Defines.anyTypeName - } - case null => Defines.anyTypeName - } - val name = nextClosureName() - val fullname = s"${fullName(lambdaExpression)}$name" - val signature = s"$returnType${parameterListSignature(lambdaExpression)}" - val codeString = code(lambdaExpression) - val methodNode_ = methodNode(lambdaExpression, name, codeString, fullname, Some(signature), filename) - - scope.pushNewScope(methodNode_) - val parameterNodes = withIndex(parameters(lambdaExpression.getDeclarator)) { (p, i) => - parameterNode(p, i) + private def setVariadicParameterInfo(parameterNodeInfos: Seq[CGlobal.ParameterInfo], func: IASTNode): Unit = { + parameterNodeInfos.lastOption.foreach { + case p: CGlobal.ParameterInfo if isVariadic(func) => + p.isVariadic = true + p.code = s"${p.code}..." + case _ => } - setVariadic(parameterNodes, lambdaExpression) - - scope.popScope() - - val astForLambda = methodAst( - methodNode_, - parameterNodes.map(Ast(_)), - astForMethodBody(Option(lambdaExpression.getBody)), - newMethodReturnNode(lambdaExpression, registerType(returnType)), - newModifierNode(ModifierTypes.LAMBDA) :: Nil - ) - val typeDeclAst = createFunctionTypeAndTypeDecl(lambdaExpression, methodNode_, name, fullname, signature) - Ast.storeInDiffGraph(astForLambda.merge(typeDeclAst), diffGraph) - - Ast(methodRefNode(lambdaExpression, codeString, fullname, methodNode_.astParentFullName)) } protected def astForFunctionDeclarator(funcDecl: IASTFunctionDeclarator): Ast = { - funcDecl.getName.resolveBinding() match { - case function: IFunction => - val returnType = typeForDeclSpecifier(funcDecl.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclSpecifier) - val fullname = fullName(funcDecl) - val templateParams = templateParameters(funcDecl).getOrElse("") - val signature = - s"$returnType${parameterListSignature(funcDecl)}" - - if (seenFunctionFullnames.add(fullname)) { - val name = shortName(funcDecl) - val codeString = code(funcDecl.getParent) - val filename = fileName(funcDecl) - val methodNode_ = methodNode(funcDecl, name, codeString, fullname, Some(signature), filename) - - scope.pushNewScope(methodNode_) - - val parameterNodes = withIndex(parameters(funcDecl)) { (p, i) => - parameterNode(p, i) - } - setVariadic(parameterNodes, funcDecl) - - scope.popScope() - - val stubAst = - methodStubAst( - methodNode_, - parameterNodes.map(Ast(_)), - newMethodReturnNode(funcDecl, registerType(returnType)) - ) - val typeDeclAst = createFunctionTypeAndTypeDecl(funcDecl, methodNode_, name, fullname, signature) - stubAst.merge(typeDeclAst) - } else { - Ast() + safeGetBinding(funcDecl.getName) match { + case Some(_: IFunction) => + val MethodFullNameInfo(name, fullName, signature, returnType) = methodFullNameInfo(funcDecl) + val codeString = code(funcDecl.getParent) + val filename = fileName(funcDecl) + + val parameterNodeInfos = thisForCPPFunctions(funcDecl) ++ withIndex(parameters(funcDecl)) { (p, i) => + parameterNodeInfo(p, i) } - case field: IField => + setVariadicParameterInfo(parameterNodeInfos, funcDecl) + + val (astParentType, astParentFullName) = methodDeclarationParentInfo() + + val methodInfo = CGlobal.MethodInfo( + name, + code = codeString, + fileName = filename, + returnType = registerType(returnType), + astParentType = astParentType, + astParentFullName = astParentFullName, + lineNumber = line(funcDecl), + columnNumber = column(funcDecl), + lineNumberEnd = lineEnd(funcDecl), + columnNumberEnd = columnEnd(funcDecl), + signature = signature, + offset(funcDecl), + parameter = parameterNodeInfos, + modifier = modifierFor(funcDecl).map(_.modifierType) + ) + registerMethodDeclaration(fullName, methodInfo) + Ast() + case Some(cVariable: CVariable) => + val name = shortName(funcDecl) + val tpe = cleanType(safeGetType(cVariable.getType)) + val codeString = code(funcDecl.getParent) + val node = localNode(funcDecl, name, codeString, registerType(tpe)) + scope.addToScope(name, (node, tpe)) + Ast(node) + case Some(cppVariable: CPPVariable) => + val name = shortName(funcDecl) + val tpe = cleanType(safeGetType(cppVariable.getType)) + val codeString = code(funcDecl.getParent) + val node = localNode(funcDecl, name, codeString, registerType(tpe)) + scope.addToScope(name, (node, tpe)) + Ast(node) + case Some(field: IField) => // TODO create a member for the field // We get here a least for function pointer member declarations in classes like: // class A { @@ -177,75 +163,122 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th // void (*foo)(int); // }; Ast() - case typeDef: ITypedef => + case Some(typeDef: ITypedef) => // TODO handle typeDecl for now we just ignore this. Ast() + case _ => + notHandledYet(funcDecl) } } - private def isCppConstructor(funcDef: IASTFunctionDefinition): Boolean = { + private def modifierFromString(image: String): List[NewModifier] = { + image match { + case "static" => List(newModifierNode(ModifierTypes.STATIC)) + case _ => Nil + } + } + + private def modifierFor(funcDef: IASTFunctionDefinition): List[NewModifier] = { + val constructorModifier = if (isCppConstructor(funcDef)) { + List(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.PUBLIC)) + } else Nil + val visibilityModifier = Try(modifierFromString(funcDef.getSyntax.getImage)).getOrElse(Nil) + constructorModifier ++ visibilityModifier + } + + private def modifierFor(funcDecl: IASTFunctionDeclarator): List[NewModifier] = { + Try(modifierFromString(funcDecl.getParent.getSyntax.getImage)).getOrElse(Nil) + } + + protected def isCppConstructor(funcDef: IASTFunctionDefinition): Boolean = { funcDef match { case cppFunc: CPPASTFunctionDefinition => cppFunc.getMemberInitializers.nonEmpty case _ => false } } + private def thisForCPPFunctions(func: IASTNode): Seq[CGlobal.ParameterInfo] = { + func match { + case cppFunc: ICPPASTFunctionDefinition if !modifierFor(cppFunc).exists(_.modifierType == ModifierTypes.STATIC) => + val maybeOwner = safeGetBinding(cppFunc.getDeclarator.getName) match { + case Some(o: ICPPBinding) if o.getOwner.isInstanceOf[CPPClassType] => + Some(o.getOwner.asInstanceOf[CPPClassType].getQualifiedName.mkString(".")) + case Some(o: ICPPBinding) if o.getOwner.isInstanceOf[CPPEnumeration] => + Some(o.getOwner.asInstanceOf[CPPEnumeration].getQualifiedName.mkString(".")) + case Some(o: ICPPBinding) if o.getOwner.isInstanceOf[CPPStructuredBindingComposite] => + Some(o.getOwner.asInstanceOf[CPPStructuredBindingComposite].getQualifiedName.mkString(".")) + case _ if cppFunc.getDeclarator.getName.isInstanceOf[ICPPASTQualifiedName] => + Some(cppFunc.getDeclarator.getName.asInstanceOf[CPPASTQualifiedName].getQualifier.mkString(".")) + case _ => None + } + maybeOwner.toSeq.map { owner => + new CGlobal.ParameterInfo( + "this", + "this", + 0, + false, + EvaluationStrategies.BY_VALUE, + line(cppFunc), + column(cppFunc), + registerType(s"$owner*") + ) + } + case _ => Seq.empty + } + } + protected def astForFunctionDefinition(funcDef: IASTFunctionDefinition): Ast = { - val filename = fileName(funcDef) - val returnType = if (isCppConstructor(funcDef)) { - typeFor(funcDef.asInstanceOf[CPPASTFunctionDefinition].getMemberInitializers.head.getInitializer) - } else typeForDeclSpecifier(funcDef.getDeclSpecifier) - val name = shortName(funcDef) - val fullname = fullName(funcDef) - val templateParams = templateParameters(funcDef).getOrElse("") - - val signature = - s"$returnType${parameterListSignature(funcDef)}" - seenFunctionFullnames.add(fullname) + val filename = fileName(funcDef) + val MethodFullNameInfo(name, fullName, signature, returnType) = methodFullNameInfo(funcDef) + registerMethodDefinition(fullName) val codeString = code(funcDef) - val methodNode_ = methodNode(funcDef, name, codeString, fullname, Some(signature), filename) + val methodNode_ = methodNode(funcDef, name, codeString, fullName, Some(signature), filename) methodAstParentStack.push(methodNode_) scope.pushNewScope(methodNode_) - val parameterNodes = withIndex(parameters(funcDef)) { (p, i) => + val implicitThisParam = thisForCPPFunctions(funcDef).map { thisParam => + val parameterNode = parameterInNode( + funcDef, + thisParam.name, + thisParam.code, + thisParam.index, + thisParam.isVariadic, + thisParam.evaluationStrategy, + thisParam.typeFullName + ) + scope.addToScope(thisParam.name, (parameterNode, thisParam.typeFullName)) + parameterNode + } + val parameterNodes = implicitThisParam ++ withIndex(parameters(funcDef)) { (p, i) => parameterNode(p, i) } setVariadic(parameterNodes, funcDef) - val modifiers = if (isCppConstructor(funcDef)) { - List(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.PUBLIC)) - } else Nil - val astForMethod = methodAst( methodNode_, parameterNodes.map(Ast(_)), astForMethodBody(Option(funcDef.getBody)), - newMethodReturnNode(funcDef, registerType(returnType)), - modifiers = modifiers + methodReturnNode(funcDef, registerType(returnType)), + modifiers = modifierFor(funcDef) ) scope.popScope() methodAstParentStack.pop() - val typeDeclAst = createFunctionTypeAndTypeDecl(funcDef, methodNode_, name, fullname, signature) + val typeDeclAst = createFunctionTypeAndTypeDecl(funcDef, methodNode_, name, fullName, signature) astForMethod.merge(typeDeclAst) } - private def parameterNode(parameter: IASTNode, paramIndex: Int): NewMethodParameterIn = { + private def parameterNodeInfo(parameter: IASTNode, paramIndex: Int): CGlobal.ParameterInfo = { val (name, codeString, tpe, variadic) = parameter match { case p: CASTParameterDeclaration => - ( - ASTStringUtil.getSimpleName(p.getDeclarator.getName), - code(p), - cleanType(typeForDeclSpecifier(p.getDeclSpecifier)), - false - ) + (shortName(p.getDeclarator), code(p), cleanType(typeForDeclSpecifier(p.getDeclSpecifier)), false) case p: CPPASTParameterDeclaration => ( - ASTStringUtil.getSimpleName(p.getDeclarator.getName), + shortName(p.getDeclarator), code(p), cleanType(typeForDeclSpecifier(p.getDeclSpecifier)), p.getDeclarator.declaresParameterPack() @@ -262,25 +295,38 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th case other => (code(other), code(other), cleanType(typeForDeclSpecifier(other)), false) } + new CGlobal.ParameterInfo( + name, + codeString, + paramIndex, + variadic, + EvaluationStrategies.BY_VALUE, + lineNumber = line(parameter), + columnNumber = column(parameter), + typeFullName = registerType(tpe) + ) + } + protected def parameterNode(parameter: IASTNode, paramIndex: Int): NewMethodParameterIn = { + val parameterInfo = parameterNodeInfo(parameter, paramIndex) val parameterNode = parameterInNode( parameter, - name, - codeString, - paramIndex, - variadic, - EvaluationStrategies.BY_VALUE, - registerType(tpe) + parameterInfo.name, + parameterInfo.code, + parameterInfo.index, + parameterInfo.isVariadic, + parameterInfo.evaluationStrategy, + parameterInfo.typeFullName ) - scope.addToScope(name, (parameterNode, tpe)) + scope.addToScope(parameterInfo.name, (parameterNode, parameterInfo.typeFullName)) parameterNode } - private def astForMethodBody(body: Option[IASTStatement]): Ast = body match { + protected def astForMethodBody(body: Option[IASTStatement]): Ast = body match { case Some(b: IASTCompoundStatement) => astForBlockStatement(b) case Some(b) => astForNode(b) - case None => blockAst(NewBlock()) + case None => blockAst(NewBlock().typeFullName(Defines.Any)) } } diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForLambdasCreator.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForLambdasCreator.scala new file mode 100644 index 000000000000..23dfd1ca601f --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForLambdasCreator.scala @@ -0,0 +1,249 @@ +package io.joern.c2cpg.astcreation + +import io.joern.c2cpg.astcreation.C2CpgScope.ScopeVariable +import io.joern.x2cpg.Ast +import io.joern.x2cpg.datastructures.Stack.* +import io.joern.x2cpg.ValidationMode +import io.joern.x2cpg.utils.NodeBuilders +import io.joern.x2cpg.utils.NodeBuilders.newClosureBindingNode +import io.joern.x2cpg.utils.NodeBuilders.newModifierNode +import io.joern.x2cpg.AstEdge +import io.shiftleft.codepropertygraph.generated.ModifierTypes +import io.shiftleft.codepropertygraph.generated.nodes.NewBlock +import io.shiftleft.codepropertygraph.generated.nodes.NewClosureBinding +import io.shiftleft.codepropertygraph.generated.nodes.NewIdentifier +import io.shiftleft.codepropertygraph.generated.nodes.NewLocal +import io.shiftleft.codepropertygraph.generated.nodes.NewMethod +import io.shiftleft.codepropertygraph.generated.nodes.NewMethodRef +import io.shiftleft.codepropertygraph.generated.nodes.NewNode +import io.shiftleft.codepropertygraph.generated.EdgeTypes +import io.shiftleft.codepropertygraph.generated.nodes.NewTypeDecl +import io.shiftleft.codepropertygraph.generated.EvaluationStrategies +import io.shiftleft.codepropertygraph.generated.nodes.NewBinding +import org.eclipse.cdt.core.dom.ast.cpp.ICPPASTCapture +import org.eclipse.cdt.core.dom.ast.cpp.ICPPASTLambdaExpression +import org.eclipse.cdt.core.dom.ast.cpp.ICPPASTLambdaExpression.CaptureDefault + +object AstForLambdasCreator { + private case class ClosureBindingEntry(node: ScopeVariable, binding: NewClosureBinding) + private case class LambdaBody(body: Ast, capturedVariables: Seq[ClosureBindingEntry]) +} + +trait AstForLambdasCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => + + import AstForLambdasCreator.* + + private def defineCapturedVariables( + lambdaExpression: ICPPASTLambdaExpression, + lambdaMethodName: String, + capturedVariables: Seq[(ScopeVariable, String)], + filename: String + ): Seq[(ClosureBindingEntry, NewLocal)] = { + capturedVariables + .groupBy(_._1.name) + .map { case (name, variables) => + val (scopeVariable, strategy) = variables.head + val closureBindingId = s"$filename:$lambdaMethodName:$name" + val closureBindingNode = newClosureBindingNode(closureBindingId, name, strategy) + val capturedLocal = localNode( + lambdaExpression, + scopeVariable.name, + scopeVariable.name, + scopeVariable.typeFullName, + Option(closureBindingId) + ) + ClosureBindingEntry(scopeVariable, closureBindingNode) -> capturedLocal + } + .toSeq + } + + private def shouldBeCaptured( + identifier: NewIdentifier, + outerScopeVariableNames: Map[String, ScopeVariable], + bodyAst: Ast + ): Boolean = { + outerScopeVariableNames.contains(identifier.name) && + bodyAst.refEdges.exists(edge => edge.src == identifier && !bodyAst.nodes.contains(edge.dst)) + } + + private def calculateCapturedVariables( + lambdaExpression: ICPPASTLambdaExpression, + bodyAst: Ast, + variablesInScope: Seq[ScopeVariable] + ): Seq[(ScopeVariable, String)] = { + val captureDefault = lambdaExpression.getCaptureDefault + val outerScopeVariableNames = variablesInScope.map(x => x.name -> x).toMap + val capturedVariables = lambdaExpression.getCaptures.toList match { + case captures if captures.isEmpty && captureDefault == CaptureDefault.UNSPECIFIED => + Seq.empty + case captures if captures.isEmpty => + val strategy = captureDefault match { + case CaptureDefault.BY_REFERENCE => EvaluationStrategies.BY_REFERENCE + case _ => EvaluationStrategies.BY_VALUE + } + bodyAst.nodes.collect { + case i: NewIdentifier if shouldBeCaptured(i, outerScopeVariableNames, bodyAst) => + (outerScopeVariableNames(i.name), strategy) + } + case other => + val validCaptures = other.filter(_.getIdentifier != null) + bodyAst.nodes.collect { + case i: NewIdentifier if shouldBeCaptured(i, outerScopeVariableNames, bodyAst) => + val maybeInCaptures = validCaptures.find(c => c.getIdentifier.getRawSignature == i.name) + val strategy = maybeInCaptures match { + case Some(c) if c.isByReference => EvaluationStrategies.BY_REFERENCE + case None if captureDefault == CaptureDefault.BY_REFERENCE => EvaluationStrategies.BY_REFERENCE + case _ => EvaluationStrategies.BY_VALUE + } + (outerScopeVariableNames(i.name), strategy) + } + } + capturedVariables.toSeq + } + + private def fixupRefEdgesForCapturedLocals(bodyAst: Ast, capturedLocals: Seq[NewLocal]): Ast = { + + def needsFixedRefEdge(i: NewIdentifier): Boolean = { + // We only need to fix the ref edge if we would cross method boundaries + bodyAst.refEdges.exists(edge => edge.src == i && !bodyAst.nodes.contains(edge.dst)) + } + + // During the traversal of the lambda body we may ref identifier to some outer param if any. + // This would cross method boundaries, which is invalid but at that time + // we do not know this. Hence, we fix that up here: + var bodyAst_ = bodyAst + val capturedLocalFixCandidates = capturedLocals.flatMap { local => + bodyAst.nodes.collect { case i: NewIdentifier if i.name == local.name && needsFixedRefEdge(i) => (i, local) } + } + capturedLocalFixCandidates.foreach { case (i, local) => + val oldEdge = bodyAst_.refEdges.find(_.src == i) + if (oldEdge.nonEmpty) { + val fixedRefEdges = bodyAst_.refEdges.toList.diff(oldEdge.toList) + val newRefEdges = fixedRefEdges ++ List(AstEdge(i, local)) + bodyAst_ = bodyAst_.copy(refEdges = newRefEdges) + } + } + bodyAst_ + } + + private def astForLambdaBody( + lambdaExpression: ICPPASTLambdaExpression, + lambdaMethodName: String, + variablesInScope: Seq[ScopeVariable], + filename: String + ): LambdaBody = { + var bodyAst = astForMethodBody(Option(lambdaExpression.getBody)) + if (bodyAst.nodes.isEmpty) return LambdaBody(Ast(), Seq.empty) + + val capturedVariables = calculateCapturedVariables(lambdaExpression, bodyAst, variablesInScope) + val bindingsToLocals = defineCapturedVariables(lambdaExpression, lambdaMethodName, capturedVariables, filename) + val capturedLocals = bindingsToLocals.map(_._2) + val closureBindingEntries = bindingsToLocals.map(_._1) + + bodyAst = fixupRefEdgesForCapturedLocals(bodyAst, capturedLocals) + + val capturedLocalsAsts = capturedLocals.map(Ast(_)) + val blockAst = bodyAst.root match { + case Some(b: NewBlock) => + Ast(b).withChildren(capturedLocalsAsts).merge(bodyAst) + case Some(_) => + Ast(blockNode(lambdaExpression.getBody)).withChildren(capturedLocalsAsts).withChild(bodyAst) + case None => Ast() + } + LambdaBody(blockAst, closureBindingEntries) + } + + private def createAndPushLambdaMethod(lambdaExpression: ICPPASTLambdaExpression): (NewMethod, LambdaBody) = { + val MethodFullNameInfo(name, fullName, signature, returnType) = methodFullNameInfo(lambdaExpression) + val filename = fileName(lambdaExpression) + val codeString = code(lambdaExpression) + val variablesInScope = scope.variablesInScope + + val lambdaMethodNode = methodNode(lambdaExpression, name, codeString, fullName, Some(signature), filename) + + methodAstParentStack.push(lambdaMethodNode) + scope.pushNewScope(lambdaMethodNode) + + val parameterNodes = withIndex(parameters(lambdaExpression.getDeclarator)) { (p, i) => parameterNode(p, i) } + setVariadic(parameterNodes, lambdaExpression) + val parameterAsts = parameterNodes.map(Ast(_)) + val lambdaBody = astForLambdaBody(lambdaExpression, name, variablesInScope, filename) + + scope.popScope() + methodAstParentStack.pop() + + val isStatic = !lambdaExpression.getCaptures.exists(c => c.capturesThisPointer()) + val returnNode = methodReturnNode(lambdaExpression, registerType(returnType)) + val virtualModifier = Some(newModifierNode(ModifierTypes.VIRTUAL)) + val staticModifier = Option.when(isStatic)(newModifierNode(ModifierTypes.STATIC)) + val privateModifier = Some(newModifierNode(ModifierTypes.PRIVATE)) + val lambdaModifier = Some(newModifierNode(ModifierTypes.LAMBDA)) + val modifiers = List(virtualModifier, staticModifier, privateModifier, lambdaModifier).flatten.map(Ast(_)) + + val lambdaMethodAst = Ast(lambdaMethodNode) + .withChildren(parameterAsts) + .withChild(lambdaBody.body) + .withChild(Ast(returnNode)) + .withChildren(modifiers) + + val parentNode = methodAstParentStack.collectFirst { case t: NewTypeDecl => t } + Ast.storeInDiffGraph(lambdaMethodAst, diffGraph) + parentNode.foreach { typeDeclNode => + diffGraph.addEdge(typeDeclNode, lambdaMethodNode, EdgeTypes.AST) + } + lambdaMethodNode -> lambdaBody + } + + private def addClosureBindingsToDiffGraph( + bindingEntries: Iterable[ClosureBindingEntry], + methodRef: NewMethodRef + ): Unit = { + bindingEntries.foreach { case ClosureBindingEntry(nodeTypeInfo, closureBinding) => + diffGraph.addNode(closureBinding) + diffGraph.addEdge(closureBinding, nodeTypeInfo.node, EdgeTypes.REF) + diffGraph.addEdge(methodRef, closureBinding, EdgeTypes.CAPTURE) + } + } + + private def createAndPushLambdaTypeDecl( + lambdaExpression: ICPPASTLambdaExpression, + lambdaMethodNode: NewMethod + ): Unit = { + registerType(lambdaMethodNode.fullName) + val (astParentType, astParentFullName) = methodDeclarationParentInfo() + val lambdaTypeDeclNode = typeDeclNode( + lambdaExpression, + lambdaMethodNode.name, + lambdaMethodNode.fullName, + lambdaMethodNode.filename, + lambdaMethodNode.fullName, + astParentType, + astParentFullName, + Seq(registerType(Defines.Function)) + ) + + val functionBinding = NewBinding() + .name(Defines.OperatorCall) + .methodFullName(lambdaMethodNode.fullName) + .signature(lambdaMethodNode.signature) + + val functionBindAst = Ast(functionBinding) + .withBindsEdge(lambdaTypeDeclNode, functionBinding) + .withRefEdge(functionBinding, lambdaMethodNode) + + Ast.storeInDiffGraph(Ast(lambdaTypeDeclNode), diffGraph) + Ast.storeInDiffGraph(functionBindAst, diffGraph) + } + + protected def astForLambdaExpression(lambdaExpression: ICPPASTLambdaExpression): Ast = { + val (lambdaMethodNode, lambdaBody) = createAndPushLambdaMethod(lambdaExpression) + val refCode = lambdaMethodNode.fullName + val refFullName = lambdaMethodNode.fullName + val refTypeFullName = lambdaMethodNode.fullName + val methodRef = methodRefNode(lambdaExpression, refCode, refFullName, refTypeFullName) + addClosureBindingsToDiffGraph(lambdaBody.capturedVariables, methodRef) + createAndPushLambdaTypeDecl(lambdaExpression, lambdaMethodNode) + Ast(methodRef) + } + +} diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForPrimitivesCreator.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForPrimitivesCreator.scala index b1c268e2cb7e..e42a885b7a94 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForPrimitivesCreator.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForPrimitivesCreator.scala @@ -1,96 +1,181 @@ package io.joern.c2cpg.astcreation -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.joern.x2cpg.{Ast, ValidationMode} +import io.joern.x2cpg.Ast +import io.joern.x2cpg.ValidationMode +import io.joern.x2cpg.Defines as X2CpgDefines +import io.shiftleft.codepropertygraph.generated.DispatchTypes +import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.nodes.NewMethod import io.shiftleft.codepropertygraph.generated.nodes.NewMethodRef +import io.shiftleft.codepropertygraph.generated.nodes.NewTypeDecl import org.eclipse.cdt.core.dom.ast.* import org.eclipse.cdt.internal.core.dom.parser.c.ICInternalBinding import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTQualifiedName import org.eclipse.cdt.internal.core.dom.parser.cpp.ICPPInternalBinding +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTFunctionDeclarator +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTIdExpression +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPField +import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.CPPVisitor +import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.EvalMemberAccess import org.eclipse.cdt.internal.core.model.ASTStringUtil +import scala.util.Try + trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => protected def astForComment(comment: IASTComment): Ast = Ast(newCommentNode(comment, code(comment), fileName(comment))) protected def astForLiteral(lit: IASTLiteralExpression): Ast = { - val tpe = cleanType(ASTTypeUtil.getType(lit.getExpressionType)) - Ast(literalNode(lit, code(lit), registerType(tpe))) + val codeString = code(lit) + val tpe = registerType(cleanType(safeGetType(lit.getExpressionType))) + if (codeString == "this") { + val thisIdentifier = identifierNode(lit, "this", "this", tpe) + scope.lookupVariable("this") match { + case Some((variable, _)) => Ast(thisIdentifier).withRefEdge(thisIdentifier, variable) + case _ => Ast(identifierNode(lit, codeString, codeString, tpe)) + } + } else { + Ast(literalNode(lit, codeString, tpe)) + } } private def namesForBinding(binding: ICInternalBinding | ICPPInternalBinding): (Option[String], Option[String]) = { val definition = binding match { - // sadly, there is no common interface defining .getDefinition - case b: ICInternalBinding => b.getDefinition.asInstanceOf[IASTFunctionDeclarator] - case b: ICPPInternalBinding => b.getDefinition.asInstanceOf[IASTFunctionDeclarator] + // sadly, there is no common interface + case b: ICInternalBinding if b.getDefinition.isInstanceOf[IASTFunctionDeclarator] => + Some(b.getDefinition.asInstanceOf[IASTFunctionDeclarator]) + case b: ICPPInternalBinding if b.getDefinition.isInstanceOf[IASTFunctionDeclarator] => + Some(b.getDefinition.asInstanceOf[IASTFunctionDeclarator]) + case b: ICInternalBinding => b.getDeclarations.find(_.isInstanceOf[IASTFunctionDeclarator]) + case b: ICPPInternalBinding => b.getDeclarations.find(_.isInstanceOf[IASTFunctionDeclarator]) + case null => None } - val typeFullName = definition.getParent match { - case d: IASTFunctionDefinition => Some(typeForDeclSpecifier(d.getDeclSpecifier)) - case _ => None + val typeFullName = definition.map(_.getParent) match { + case Some(d: IASTFunctionDefinition) => Some(typeForDeclSpecifier(d.getDeclSpecifier)) + case Some(d: IASTSimpleDeclaration) => Some(typeForDeclSpecifier(d.getDeclSpecifier)) + case _ => None } - (Some(this.fullName(definition)), typeFullName) + (definition.map(fullName), typeFullName) } private def maybeMethodRefForIdentifier(ident: IASTNode): Option[NewMethodRef] = { ident match { case id: IASTIdExpression if id.getName != null => - id.getName.resolveBinding() - val (mayBeFullName, mayBeTypeFullName) = id.getName.getBinding match { - case binding: ICInternalBinding if binding.getDefinition.isInstanceOf[IASTFunctionDeclarator] => + val (mayBeFullName, mayBeTypeFullName) = safeGetBinding(id) match { + case Some(binding: ICInternalBinding) if binding.getDefinition.isInstanceOf[IASTFunctionDeclarator] => namesForBinding(binding) - case binding: ICPPInternalBinding if binding.getDefinition.isInstanceOf[IASTFunctionDeclarator] => + case Some(binding: ICInternalBinding) + if binding.getDeclarations != null && + binding.getDeclarations.exists(_.isInstanceOf[IASTFunctionDeclarator]) => + namesForBinding(binding) + case Some(binding: ICPPInternalBinding) if binding.getDefinition.isInstanceOf[IASTFunctionDeclarator] => + namesForBinding(binding) + case Some(binding: ICPPInternalBinding) + if binding.getDeclarations != null && + binding.getDeclarations.exists(_.isInstanceOf[CPPASTFunctionDeclarator]) => namesForBinding(binding) case _ => (None, None) } for { fullName <- mayBeFullName typeFullName <- mayBeTypeFullName - } yield methodRefNode(ident, code(ident), fullName, typeFullName) + } yield methodRefNode(ident, code(ident), fullName, registerType(cleanType(typeFullName))) case _ => None } } + private def isInCurrentScope(ident: CPPASTIdExpression, owner: String): Boolean = { + val isInMethodScope = + Try(CPPVisitor.getContainingScope(ident).getScopeName.toString).toOption.exists(s => + s.startsWith(s"$owner::") || s.contains(s"::$owner::") + ) + isInMethodScope || methodAstParentStack.collectFirst { + case typeDecl: NewTypeDecl if typeDecl.fullName == owner => typeDecl + case method: NewMethod if method.fullName.startsWith(owner) => method + }.nonEmpty + } + + private def nameForIdentifier(ident: IASTNode): String = { + ident match { + case id: IASTIdExpression => ASTStringUtil.getSimpleName(id.getName) + case id: IASTName => + val name = ASTStringUtil.getSimpleName(id) + if (name.isEmpty) safeGetBinding(id).map(_.getName).getOrElse(uniqueName("name", "", "")._1) + else name + case _ => code(ident) + } + } + + private def syntheticThisAccess(ident: CPPASTIdExpression, identifierName: String): String | Ast = { + val tpe = ident.getName.getBinding match { + case f: CPPField => safeGetType(f.getType) + case _ => typeFor(ident) + } + Try(ident.getEvaluation).toOption match { + case Some(e: EvalMemberAccess) => + val ownerTypeRaw = cleanType(safeGetType(e.getOwnerType)) + val deref = if (e.isPointerDeref) "*" else "" + val ownerType = registerType(s"$ownerTypeRaw$deref") + if (isInCurrentScope(ident, ownerTypeRaw)) { + scope.lookupVariable("this") match { + case Some((variable, _)) => + val op = Operators.indirectFieldAccess + val code = s"this->$identifierName" + val thisIdentifier = identifierNode(ident, "this", "this", ownerType) + val member = fieldIdentifierNode(ident, identifierName, identifierName) + val ma = + callNode(ident, code, op, op, DispatchTypes.STATIC_DISPATCH, None, Some(registerType(cleanType(tpe)))) + callAst(ma, Seq(Ast(thisIdentifier).withRefEdge(thisIdentifier, variable), Ast(member))) + case None => tpe + } + } else tpe + case _ => tpe + } + } + + private def typeNameForIdentifier(ident: IASTNode, identifierName: String): String | Ast = { + val variableOption = scope.lookupVariable(identifierName) + variableOption match { + case Some((_, variableTypeName)) => variableTypeName + case None if ident.isInstanceOf[IASTName] && ident.asInstanceOf[IASTName].getBinding != null => + val id = ident.asInstanceOf[IASTName] + id.getBinding match { + case v: IVariable => + v.getType match { + case f: IFunctionType => f.getReturnType.toString + case other => other.toString + } + case other => other.getName + } + case None if ident.isInstanceOf[IASTName] => + typeFor(ident.getParent) + case None if ident.isInstanceOf[CPPASTIdExpression] => + syntheticThisAccess(ident.asInstanceOf[CPPASTIdExpression], identifierName) + case None => typeFor(ident) + } + } + protected def astForIdentifier(ident: IASTNode): Ast = { maybeMethodRefForIdentifier(ident) match { case Some(ref) => Ast(ref) case None => - val identifierName = ident match { - case id: IASTIdExpression => ASTStringUtil.getSimpleName(id.getName) - case id: IASTName if ASTStringUtil.getSimpleName(id).isEmpty && id.getBinding != null => id.getBinding.getName - case id: IASTName if ASTStringUtil.getSimpleName(id).isEmpty => uniqueName("name", "", "")._1 - case _ => code(ident) - } - val variableOption = scope.lookupVariable(identifierName) - val identifierTypeName = variableOption match { - case Some((_, variableTypeName)) => variableTypeName - case None if ident.isInstanceOf[IASTName] && ident.asInstanceOf[IASTName].getBinding != null => - val id = ident.asInstanceOf[IASTName] - id.getBinding match { - case v: IVariable => - v.getType match { - case f: IFunctionType => f.getReturnType.toString - case other => other.toString - } - case other => other.getName + val identifierName = nameForIdentifier(ident) + typeNameForIdentifier(ident, identifierName) match { + case identifierTypeName: String => + val node = identifierNode(ident, identifierName, code(ident), registerType(cleanType(identifierTypeName))) + scope.lookupVariable(identifierName) match { + case Some((variable, _)) => Ast(node).withRefEdge(node, variable) + case _ => Ast(node) } - case None if ident.isInstanceOf[IASTName] => - typeFor(ident.getParent) - case None => typeFor(ident) - } - - val node = identifierNode(ident, identifierName, code(ident), registerType(cleanType(identifierTypeName))) - variableOption match { - case Some((variable, _)) => - Ast(node).withRefEdge(node, variable) - case None => Ast(node) + case ast: Ast => ast } } } protected def astForFieldReference(fieldRef: IASTFieldReference): Ast = { val op = if (fieldRef.isPointerDereference) Operators.indirectFieldAccess else Operators.fieldAccess - val ma = callNode(fieldRef, code(fieldRef), op, op, DispatchTypes.STATIC_DISPATCH) + val ma = callNode(fieldRef, code(fieldRef), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) val owner = astForExpression(fieldRef.getFieldOwner) val member = fieldIdentifierNode(fieldRef, fieldRef.getFieldName.toString, fieldRef.getFieldName.toString) callAst(ma, List(owner, Ast(member))) @@ -101,7 +186,7 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { t protected def astForInitializerList(l: IASTInitializerList): Ast = { val op = Operators.arrayInitializer - val initCallNode = callNode(l, code(l), op, op, DispatchTypes.STATIC_DISPATCH) + val initCallNode = callNode(l, code(l), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) val MAX_INITIALIZERS = 1000 val clauses = l.getClauses.slice(0, MAX_INITIALIZERS) @@ -111,7 +196,7 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { t val ast = callAst(initCallNode, args) if (l.getClauses.length > MAX_INITIALIZERS) { val placeholder = - literalNode(l, "", Defines.anyTypeName).argumentIndex(MAX_INITIALIZERS) + literalNode(l, "", Defines.Any).argumentIndex(MAX_INITIALIZERS) ast.withChild(Ast(placeholder)).withArgEdge(initCallNode, placeholder) } else { ast @@ -120,7 +205,7 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { t protected def astForQualifiedName(qualId: CPPASTQualifiedName): Ast = { val op = Operators.fieldAccess - val ma = callNode(qualId, code(qualId), op, op, DispatchTypes.STATIC_DISPATCH) + val ma = callNode(qualId, code(qualId), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) def fieldAccesses(names: List[IASTNode], argIndex: Int = -1): Ast = names match { case Nil => Ast() @@ -129,7 +214,7 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { t case head :: tail => val codeString = s"${code(head)}::${tail.map(code).mkString("::")}" val callNode_ = - callNode(head, code(head), op, op, DispatchTypes.STATIC_DISPATCH) + callNode(head, code(head), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) .argumentIndex(argIndex) callNode_.code = codeString val arg1 = astForNode(head) @@ -142,7 +227,7 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { t val owner = if (qualifier != Ast()) { qualifier } else { - Ast(literalNode(qualId.getLastName, "", Defines.anyTypeName)) + Ast(literalNode(qualId.getLastName, "", Defines.Any)) } val member = fieldIdentifierNode( diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForStatementsCreator.scala index badeda55e61a..3f754d937215 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForStatementsCreator.scala @@ -1,28 +1,33 @@ package io.joern.c2cpg.astcreation import io.joern.c2cpg.parser.CdtParser +import io.joern.x2cpg.Ast +import io.joern.x2cpg.ValidationMode import io.shiftleft.codepropertygraph.generated.ControlStructureTypes -import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.AstNodeNew import io.shiftleft.codepropertygraph.generated.nodes.ExpressionNew +import io.shiftleft.codepropertygraph.generated.DispatchTypes +import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.nodes.NewCall +import io.shiftleft.codepropertygraph.generated.nodes.NewLocal import org.eclipse.cdt.core.dom.ast.* import org.eclipse.cdt.core.dom.ast.cpp.* import org.eclipse.cdt.core.dom.ast.gnu.IGNUASTGotoStatement import org.eclipse.cdt.internal.core.dom.parser.c.CASTIfStatement import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTIfStatement import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTNamespaceAlias +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTSimpleDeclaration import org.eclipse.cdt.internal.core.model.ASTStringUtil import java.nio.file.Paths +import scala.collection.mutable trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - import io.joern.c2cpg.astcreation.AstCreatorHelper.OptionSafeAst - protected def astForBlockStatement(blockStmt: IASTCompoundStatement, order: Int = -1): Ast = { val codeString = code(blockStmt) - val blockCode = if (codeString == "{}" || codeString.isEmpty) Defines.empty else codeString - val node = blockNode(blockStmt, blockCode, registerType(Defines.voidTypeName)) + val blockCode = if (codeString == "{}" || codeString.isEmpty) Defines.Empty else codeString + val node = blockNode(blockStmt, blockCode, registerType(Defines.Void)) .order(order) .argumentIndex(order) scope.pushNewScope(node) @@ -36,33 +41,133 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t blockAst(node, childAsts.toList) } + private def hasValidArrayModifier(arrayDecl: IASTArrayDeclarator): Boolean = + arrayDecl.getArrayModifiers.nonEmpty && arrayDecl.getArrayModifiers.forall(_.getConstantExpression != null) + + private def astsForStructuredBindingDeclaration( + struct: ICPPASTStructuredBindingDeclaration, + init: Option[IASTInitializerClause] = None + ): Seq[Ast] = { + def leftAst(astName: IASTNode, localName: String, codeString: String, tpe: String): (NewCall, NewLocal, Ast) = { + val op = Operators.assignment + val assignmentCode = s"$localName = $codeString" + val assignmentCallNode = callNode(astName, assignmentCode, op, op, DispatchTypes.STATIC_DISPATCH, None, Some(tpe)) + val localNameNode = localNode(astName, localName, localName, tpe) + scope.addToScope(localName, (localNameNode, tpe)) + val localId = identifierNode(astName, code(astName), code(astName), tpe) + val leftAst = Ast(localId).withRefEdge(localId, localNameNode) + (assignmentCallNode, localNameNode, leftAst) + } + + val initializer = init.getOrElse(struct.getInitializer) + val tmpName = uniqueName("tmp", "", "")._1 + val tpe = registerType(typeFor(initializer)) + val localTmpNode = localNode(struct, tmpName, tmpName, tpe) + scope.addToScope(tmpName, (localTmpNode, tpe)) + + val idNode = identifierNode(struct, tmpName, tmpName, tpe) + val rhsAst = astForNode(initializer) + val op = Operators.assignment + val assignmentCode = s"$tmpName = ${code(initializer)}" + val assignmentCallNode = callNode(struct, assignmentCode, op, op, DispatchTypes.STATIC_DISPATCH, None, Some(tpe)) + val assignmentCallAst = callAst(assignmentCallNode, List(Ast(idNode).withRefEdge(idNode, localTmpNode), rhsAst)) + + val accessAsts = if typeFor(initializer).endsWith("]") then { + struct.getNames.zipWithIndex.flatMap { case (astName, index) => + val localName = code(astName) + val tpe = registerType(typeFor(astName)) + val codeString = s"$tmpName[$index]" + val (assignmentCallNode, localNode, lhsAst) = leftAst(astName, localName, codeString, tpe) + val op = Operators.indexAccess + val arrayIndexCallNode = callNode(astName, codeString, op, op, DispatchTypes.STATIC_DISPATCH, None, Some(tpe)) + val idNode = identifierNode(astName, tmpName, tmpName, tpe) + val indexNode = literalNode(astName, index.toString, registerType("int")) + val arrayIndexCallAst = callAst(arrayIndexCallNode, List(Ast(idNode), Ast(indexNode))) + Seq(Ast(localNode), Ast(assignmentCallNode).withChildren(List(lhsAst, arrayIndexCallAst))) + } + } else { + struct.getNames.flatMap { astName => + val localName = code(astName) + val tpe = registerType(typeFor(astName)) + val codeString = s"$tmpName.$localName" + val (assignmentCallNode, localNode, lhsAst) = leftAst(astName, localName, codeString, tpe) + val op = Operators.memberAccess + val memberAccessCallNode = callNode(astName, codeString, op, op, DispatchTypes.STATIC_DISPATCH, None, Some(tpe)) + val idNode = identifierNode(astName, tmpName, tmpName, tpe) + val fieldIdNode = fieldIdentifierNode(astName, localName, localName) + val memberAccessCallAst = callAst(memberAccessCallNode, List(Ast(idNode), Ast(fieldIdNode))) + Seq(Ast(localNode), Ast(assignmentCallNode).withChildren(List(lhsAst, memberAccessCallAst))) + } + } + + Seq(Ast(localTmpNode), assignmentCallAst) ++ accessAsts + } + + private def isCoroutineCall(decl: IASTDeclaration): Boolean = { + decl.getRawSignature.startsWith("co_yield ") || decl.getRawSignature.startsWith("co_await ") + } + + /** CDT is unable to parse co_yield or co_await calls into actual AST elements. Hence, this hack to recover the + * structure from CPPASTSimpleDeclaration. + */ + private def astForCoroutineCall(decl: CPPASTSimpleDeclaration): Ast = { + val op = decl.getRawSignature match { + case s if s.startsWith("co_yield ") => ".yield" + case _ => ".await" + } + val node = callNode(decl, code(decl), op, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = decl.getDeclarators.zipWithIndex.map { case (d, i) => astForDeclarator(decl, d, i) } + callAst(node, argAsts.toSeq) + } + + private def astsForIASTSimpleDeclaration(simpleDecl: IASTSimpleDeclaration): Seq[Ast] = { + val declAsts = simpleDecl.getDeclarators.zipWithIndex.map { + case (d: IASTFunctionDeclarator, _) => astForFunctionDeclarator(d) + case (d, i) => astForDeclarator(simpleDecl, d, i) + } + val arrayModCallsAsts = simpleDecl.getDeclarators + .collect { case d: IASTArrayDeclarator if hasValidArrayModifier(d) => d } + .map { d => + val name = Operators.alloc + val tpe = registerType(typeFor(d)) + val codeString = code(d) + val allocCallNode = callNode(d, codeString, name, name, DispatchTypes.STATIC_DISPATCH, None, Some(tpe)) + val allocCallAst = callAst(allocCallNode, d.getArrayModifiers.toIndexedSeq.map(astForNode)) + val operatorName = Operators.assignment + val assignmentCallNode = + callNode(d, codeString, operatorName, operatorName, DispatchTypes.STATIC_DISPATCH, None, Some(tpe)) + val left = astForNode(d.getName) + callAst(assignmentCallNode, List(left, allocCallAst)) + } + val initCallsAsts = simpleDecl.getDeclarators.filter(_.getInitializer != null).map { d => + astForInitializer(d, d.getInitializer) + } + val asts = Seq.from(declAsts ++ arrayModCallsAsts ++ initCallsAsts) + setArgumentIndices(asts) + asts + } + private def astsForDeclarationStatement(decl: IASTDeclarationStatement): Seq[Ast] = decl.getDeclaration match { - case simplDecl: IASTSimpleDeclaration - if simplDecl.getDeclarators.headOption.exists(_.isInstanceOf[IASTFunctionDeclarator]) => - Seq(astForFunctionDeclarator(simplDecl.getDeclarators.head.asInstanceOf[IASTFunctionDeclarator])) - case simplDecl: IASTSimpleDeclaration => - val locals = - simplDecl.getDeclarators.zipWithIndex.toList.map { case (d, i) => astForDeclarator(simplDecl, d, i) } - val calls = - simplDecl.getDeclarators.filter(_.getInitializer != null).toList.map { d => - astForInitializer(d, d.getInitializer) - } - locals ++ calls - case s: ICPPASTStaticAssertDeclaration => Seq(astForStaticAssert(s)) - case usingDeclaration: ICPPASTUsingDeclaration => handleUsingDeclaration(usingDeclaration) - case alias: ICPPASTAliasDeclaration => Seq(astForAliasDeclaration(alias)) - case func: IASTFunctionDefinition => Seq(astForFunctionDefinition(func)) - case alias: CPPASTNamespaceAlias => Seq(astForNamespaceAlias(alias)) - case asm: IASTASMDeclaration => Seq(astForASMDeclaration(asm)) - case _: ICPPASTUsingDirective => Seq.empty - case declaration => Seq(astForNode(declaration)) + case struct: ICPPASTStructuredBindingDeclaration => astsForStructuredBindingDeclaration(struct) + case declStmt: CPPASTSimpleDeclaration if isCoroutineCall(declStmt) => Seq(astForCoroutineCall(declStmt)) + case simpleDecl: IASTSimpleDeclaration => astsForIASTSimpleDeclaration(simpleDecl) + case s: ICPPASTStaticAssertDeclaration => Seq(astForStaticAssert(s)) + case usingDeclaration: ICPPASTUsingDeclaration => handleUsingDeclaration(usingDeclaration) + case alias: ICPPASTAliasDeclaration => Seq(astForAliasDeclaration(alias)) + case func: IASTFunctionDefinition => Seq(astForFunctionDefinition(func)) + case alias: CPPASTNamespaceAlias => Seq(astForNamespaceAlias(alias)) + case asm: IASTASMDeclaration => Seq(astForASMDeclaration(asm)) + case _: ICPPASTUsingDirective => Seq.empty + case declaration => astsForDeclaration(declaration) } private def astForReturnStatement(ret: IASTReturnStatement): Ast = { val cpgReturn = returnNode(ret, code(ret)) - val expr = nullSafeAst(ret.getReturnValue) - Ast(cpgReturn).withChild(expr).withArgEdge(cpgReturn, expr.root) + nullSafeAst(ret.getReturnValue) match { + case retAst if retAst.root.isDefined => Ast(cpgReturn).withChild(retAst).withArgEdge(cpgReturn, retAst.root.get) + case _ => Ast(cpgReturn) + } } private def astForBreakStatement(br: IASTBreakStatement): Ast = { @@ -103,12 +208,22 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t controlStructureAst(doNode, Option(conditionAst), bodyAst, placeConditionLast = true) } - private def astForSwitchStatement(switchStmt: IASTSwitchStatement): Ast = { - val code = s"switch(${nullSafeCode(switchStmt.getControllerExpression)})" - val switchNode = controlStructureNode(switchStmt, ControlStructureTypes.SWITCH, code) + private def astForSwitchStatement(switchStmt: IASTSwitchStatement): Seq[Ast] = { + val initAsts = switchStmt match { + case s: ICPPASTSwitchStatement => + nullSafeAst(s.getInitializerStatement) ++ nullSafeAst(s.getControllerDeclaration) + case _ => + Seq.empty + } + + val codeString = code(switchStmt) + val switchNode = controlStructureNode(switchStmt, ControlStructureTypes.SWITCH, codeString) val conditionAst = astForConditionExpression(switchStmt.getControllerExpression) val stmtAsts = nullSafeAst(switchStmt.getBody) - controlStructureAst(switchNode, Option(conditionAst), stmtAsts) + + val finalAsts = initAsts :+ controlStructureAst(switchNode, Option(conditionAst), stmtAsts) + setArgumentIndices(finalAsts) + finalAsts } private def astsForCaseStatement(caseStmt: IASTCaseStatement): Seq[Ast] = { @@ -145,12 +260,12 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t val r = statement match { case expr: IASTExpressionStatement => Seq(astForExpression(expr.getExpression)) case block: IASTCompoundStatement => Seq(astForBlockStatement(block, argIndex)) - case ifStmt: IASTIfStatement => Seq(astForIf(ifStmt)) + case ifStmt: IASTIfStatement => astForIf(ifStmt) case whileStmt: IASTWhileStatement => Seq(astForWhile(whileStmt)) case forStmt: IASTForStatement => Seq(astForFor(forStmt)) case forStmt: ICPPASTRangeBasedForStatement => Seq(astForRangedFor(forStmt)) case doStmt: IASTDoStatement => Seq(astForDoStatement(doStmt)) - case switchStmt: IASTSwitchStatement => Seq(astForSwitchStatement(switchStmt)) + case switchStmt: IASTSwitchStatement => astForSwitchStatement(switchStmt) case ret: IASTReturnStatement => Seq(astForReturnStatement(ret)) case br: IASTBreakStatement => Seq(astForBreakStatement(br)) case cont: IASTContinueStatement => Seq(astForContinueStatement(cont)) @@ -174,7 +289,8 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t // We only handle un-parsable macros here for now val isFromMacroExpansion = statement.getProblem.getNodeLocations.exists(_.isInstanceOf[IASTMacroExpansionLocation]) val asts = if (isFromMacroExpansion) { - new CdtParser(config).parse(statement.getRawSignature, Paths.get(statement.getContainingFilename)) match + new CdtParser(config, headerFileFinder, mutable.LinkedHashSet.empty) + .parse(statement.getRawSignature, Paths.get(statement.getContainingFilename)) match case Some(node) => node.getDeclarations.toIndexedSeq.flatMap(astsForDeclaration) case None => Seq.empty } else { @@ -193,7 +309,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t private def astForConditionExpression(expression: IASTExpression, explicitArgumentIndex: Option[Int] = None): Ast = { val ast = expression match { case exprList: IASTExpressionList => - val compareAstBlock = blockNode(expression, Defines.empty, registerType(Defines.voidTypeName)) + val compareAstBlock = blockNode(expression, Defines.Empty, registerType(Defines.Void)) scope.pushNewScope(compareAstBlock) val compareBlockAstChildren = exprList.getExpressions.toList.map(nullSafeAst) setArgumentIndices(compareBlockAstChildren) @@ -217,28 +333,35 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t val code = s"for ($codeInit$codeCond;$codeIter)" val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code) - val initAstBlock = blockNode(forStmt, Defines.empty, registerType(Defines.voidTypeName)) - scope.pushNewScope(initAstBlock) - val initAst = blockAst(initAstBlock, nullSafeAst(forStmt.getInitializerStatement, 1).toList) - scope.popScope() - - val compareAst = astForConditionExpression(forStmt.getConditionExpression, Option(2)) - val updateAst = nullSafeAst(forStmt.getIterationExpression, 3) - val bodyAsts = nullSafeAst(forStmt.getBody, 4) - forAst(forNode, Seq(), Seq(initAst), Seq(compareAst), Seq(updateAst), bodyAsts) + val (localAsts, initAsts) = + nullSafeAst(forStmt.getInitializerStatement).partition(_.root.exists(_.isInstanceOf[NewLocal])) + setArgumentIndices(initAsts) + val compareAst = astForConditionExpression(forStmt.getConditionExpression) + val updateAst = nullSafeAst(forStmt.getIterationExpression) + val bodyAsts = nullSafeAst(forStmt.getBody) + forAst(forNode, localAsts, initAsts, Seq(compareAst), Seq(updateAst), bodyAsts) } private def astForRangedFor(forStmt: ICPPASTRangeBasedForStatement): Ast = { val codeDecl = nullSafeCode(forStmt.getDeclaration) val codeInit = nullSafeCode(forStmt.getInitializerClause) - - val code = s"for ($codeDecl:$codeInit)" - val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code) - - val initAst = astForNode(forStmt.getInitializerClause) - val declAst = astsForDeclaration(forStmt.getDeclaration) - val stmtAst = nullSafeAst(forStmt.getBody) - controlStructureAst(forNode, None, Seq(initAst) ++ declAst ++ stmtAst) + val code = s"for ($codeDecl:$codeInit)" + val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code) + forStmt.getDeclaration match { + case declaration: ICPPASTStructuredBindingDeclaration => + val (localAsts, initAsts) = astsForStructuredBindingDeclaration(declaration, Some(forStmt.getInitializerClause)) + .partition(_.root.exists(_.isInstanceOf[NewLocal])) + setArgumentIndices(initAsts) + val bodyAst = nullSafeAst(forStmt.getBody) + forAst(forNode, localAsts, initAsts.filterNot(_.nodes.isEmpty), Seq.empty, Seq.empty, bodyAst) + case _ => + val init = astForNode(forStmt.getInitializerClause) + val declAsts = astsForDeclaration(forStmt.getDeclaration) + setArgumentIndices(init +: declAsts) + val (localAsts, initAsts) = (init +: declAsts).partition(_.root.exists(_.isInstanceOf[NewLocal])) + val bodyAst = nullSafeAst(forStmt.getBody) + forAst(forNode, localAsts, initAsts.filterNot(_.nodes.isEmpty), Seq.empty, Seq.empty, bodyAst) + } } private def astForWhile(whileStmt: IASTWhileStatement): Ast = { @@ -254,28 +377,29 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t ) } - private def astForIf(ifStmt: IASTIfStatement): Ast = { - val (code, conditionAst) = ifStmt match { + private def astForIf(ifStmt: IASTIfStatement): Seq[Ast] = { + val initAsts = ifStmt match { + case s: ICPPASTIfStatement => nullSafeAst(s.getInitializerStatement) + case _ => Seq.empty + } + val conditionAst = ifStmt match { case s @ (_: CASTIfStatement | _: CPPASTIfStatement) if s.getConditionExpression != null => - val c = s"if (${nullSafeCode(s.getConditionExpression)})" - val compareAst = astForConditionExpression(s.getConditionExpression) - (c, compareAst) + astForConditionExpression(s.getConditionExpression) case s: CPPASTIfStatement if s.getConditionExpression == null => - val c = s"if (${nullSafeCode(s.getConditionDeclaration)})" - val exprBlock = blockNode(s.getConditionDeclaration, Defines.empty, Defines.voidTypeName) + val exprBlock = blockNode(s.getConditionDeclaration, Defines.Empty, Defines.Void) scope.pushNewScope(exprBlock) val a = astsForDeclaration(s.getConditionDeclaration) setArgumentIndices(a) scope.popScope() - (c, blockAst(exprBlock, a.toList)) + blockAst(exprBlock, a.toList) } - val ifNode = controlStructureNode(ifStmt, ControlStructureTypes.IF, code) + val ifNode = controlStructureNode(ifStmt, ControlStructureTypes.IF, code(ifStmt)) val thenAst = ifStmt.getThenClause match { case block: IASTCompoundStatement => astForBlockStatement(block) case other if other != null => - val thenBlock = blockNode(other, Defines.empty, Defines.voidTypeName) + val thenBlock = blockNode(other, Defines.Empty, Defines.Void) scope.pushNewScope(thenBlock) val a = astsForStatement(other) setArgumentIndices(a) @@ -291,7 +415,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t Ast(elseNode).withChild(elseAst) case other if other != null => val elseNode = controlStructureNode(ifStmt.getElseClause, ControlStructureTypes.ELSE, "else") - val elseBlock = blockNode(other, Defines.empty, Defines.voidTypeName) + val elseBlock = blockNode(other, Defines.Empty, Defines.Void) scope.pushNewScope(elseBlock) val a = astsForStatement(other) setArgumentIndices(a) @@ -299,6 +423,9 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t Ast(elseNode).withChild(blockAst(elseBlock, a.toList)) case _ => Ast() } - controlStructureAst(ifNode, Option(conditionAst), Seq(thenAst, elseAst)) + + val finalAsts = initAsts :+ controlStructureAst(ifNode, Option(conditionAst), Seq(thenAst, elseAst)) + setArgumentIndices(finalAsts) + finalAsts } } diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForTypesCreator.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForTypesCreator.scala index 8bfbf0a2e05a..e98dd2322da7 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForTypesCreator.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstForTypesCreator.scala @@ -3,11 +3,13 @@ package io.joern.c2cpg.astcreation import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import io.joern.x2cpg.{Ast, ValidationMode} +import io.joern.x2cpg.Defines as X2CpgDefines import org.eclipse.cdt.core.dom.ast.* import org.eclipse.cdt.core.dom.ast.cpp.* import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTAliasDeclaration import org.eclipse.cdt.internal.core.model.ASTStringUtil import io.joern.x2cpg.datastructures.Stack.* +import org.apache.commons.lang3.StringUtils trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => @@ -16,10 +18,9 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: case _ => false } - private def isTypeDef(decl: IASTSimpleDeclaration): Boolean = - code(decl).startsWith("typedef") + private def isTypeDef(decl: IASTSimpleDeclaration): Boolean = decl.getRawSignature.startsWith("typedef") - protected def templateParameters(e: IASTNode): Option[String] = { + private def templateParameters(e: IASTNode): Option[String] = { val templateDeclaration = e match { case _: IASTElaboratedTypeSpecifier | _: IASTFunctionDeclarator | _: IASTCompositeTypeSpecifier if e.getParent != null => @@ -34,11 +35,10 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: } private def astForNamespaceDefinition(namespaceDefinition: ICPPASTNamespaceDefinition): Ast = { - val (name, fullname) = - uniqueName("namespace", namespaceDefinition.getName.getLastName.toString, fullName(namespaceDefinition)) - val codeString = code(namespaceDefinition) + val TypeFullNameInfo(name, fullName) = typeFullNameInfo(namespaceDefinition) + val codeString = code(namespaceDefinition) val cpgNamespace = - newNamespaceBlockNode(namespaceDefinition, name, fullname, codeString, fileName(namespaceDefinition)) + newNamespaceBlockNode(namespaceDefinition, name, fullName, codeString, fileName(namespaceDefinition)) scope.pushNewScope(cpgNamespace) val childrenAsts = namespaceDefinition.getDeclarations.flatMap { decl => @@ -52,66 +52,105 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: } protected def astForNamespaceAlias(namespaceAlias: ICPPASTNamespaceAlias): Ast = { - val name = ASTStringUtil.getSimpleName(namespaceAlias.getAlias) - val fullname = fullName(namespaceAlias) - + val TypeFullNameInfo(name, fullName) = typeFullNameInfo(namespaceAlias) if (!isQualifiedName(name)) { - usingDeclarationMappings.put(name, fullname) + usingDeclarationMappings.put(name, fullName) } - val codeString = code(namespaceAlias) - val cpgNamespace = newNamespaceBlockNode(namespaceAlias, name, fullname, codeString, fileName(namespaceAlias)) + val cpgNamespace = newNamespaceBlockNode(namespaceAlias, name, fullName, codeString, fileName(namespaceAlias)) Ast(cpgNamespace) } + private def isAssignmentFromBrokenMacro(declaration: IASTSimpleDeclaration, declarator: IASTDeclarator): Boolean = + declaration.getParent.isInstanceOf[IASTTranslationUnit] && + declarator.getInitializer.isInstanceOf[IASTEqualsInitializer] + protected def astForDeclarator(declaration: IASTSimpleDeclaration, declarator: IASTDeclarator, index: Int): Ast = { - val name = ASTStringUtil.getSimpleName(declarator.getName) + val name = shortName(declarator) declaration match { case d if isTypeDef(d) && shortName(d.getDeclSpecifier).nonEmpty => val filename = fileName(declaration) - val tpe = registerType(typeFor(declarator)) - Ast(typeDeclNode(declarator, name, registerType(name), filename, code(d), alias = Option(tpe))) + val typeDefName = if (name.isEmpty) { + safeGetBinding(declarator.getName).map(b => registerType(b.getName)) + } else { + Option(registerType(name)) + } + val tpe = registerType(typeFor(declarator)) + Ast( + typeDeclNode( + declarator, + typeDefName.getOrElse(name), + typeDefName.getOrElse(name), + filename, + code(d), + alias = Option(tpe) + ) + ) case d if parentIsClassDef(d) => val tpe = declarator match { - case _: IASTArrayDeclarator => registerType(typeFor(declarator)) - case _ => registerType(typeForDeclSpecifier(declaration.getDeclSpecifier)) + case _: IASTArrayDeclarator => registerType(cleanType(typeFor(declarator))) + case _ => registerType(cleanType(typeForDeclSpecifier(declaration.getDeclSpecifier, index = index))) } Ast(memberNode(declarator, name, code(declarator), tpe)) - case _ if declarator.isInstanceOf[IASTArrayDeclarator] => - val tpe = registerType(typeFor(declarator)) - val codeTpe = typeFor(declarator, stripKeywords = false) - val node = localNode(declarator, name, s"$codeTpe $name", tpe) - scope.addToScope(name, (node, tpe)) - Ast(node) + case d if isAssignmentFromBrokenMacro(d, declarator) && scope.lookupVariable(name).nonEmpty => + Ast() case _ => - val tpe = registerType( - cleanType(typeForDeclSpecifier(declaration.getDeclSpecifier, stripKeywords = true, index)) - ) - val codeTpe = typeForDeclSpecifier(declaration.getDeclSpecifier, stripKeywords = false, index) - val node = localNode(declarator, name, s"$codeTpe $name", tpe) + val tpe = declarator match { + case arrayDecl: IASTArrayDeclarator => registerType(cleanType(typeFor(arrayDecl))) + case _ => registerType(cleanType(typeForDeclSpecifier(declaration.getDeclSpecifier, index = index))) + } + val code = codeForDeclarator(declaration, declarator) + val node = localNode(declarator, name, code, tpe) scope.addToScope(name, (node, tpe)) Ast(node) } + } + private def codeForDeclarator(declaration: IASTSimpleDeclaration, declarator: IASTDeclarator): String = { + val specCode = declaration.getDeclSpecifier.getRawSignature + val declCodeRaw = declarator.getRawSignature + val declCode = declarator.getInitializer match { + case null => declCodeRaw + case _ => declCodeRaw.replace(declarator.getInitializer.getRawSignature, "") + } + val normalizedCode = StringUtils.normalizeSpace(s"$specCode $declCode") + normalizedCode.strip() } protected def astForInitializer(declarator: IASTDeclarator, init: IASTInitializer): Ast = init match { case i: IASTEqualsInitializer => val operatorName = Operators.assignment val callNode_ = - callNode(declarator, code(declarator), operatorName, operatorName, DispatchTypes.STATIC_DISPATCH) + callNode( + declarator, + code(declarator), + operatorName, + operatorName, + DispatchTypes.STATIC_DISPATCH, + None, + Some(X2CpgDefines.Any) + ) val left = astForNode(declarator.getName) val right = astForNode(i.getInitializerClause) callAst(callNode_, List(left, right)) case i: ICPPASTConstructorInitializer => - val name = ASTStringUtil.getSimpleName(declarator.getName) - val callNode_ = callNode(declarator, code(declarator), name, name, DispatchTypes.STATIC_DISPATCH) - val args = i.getArguments.toList.map(x => astForNode(x)) + val name = ASTStringUtil.getSimpleName(declarator.getName) + val callNode_ = + callNode(declarator, code(declarator), name, name, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) + val args = i.getArguments.toList.map(x => astForNode(x)) callAst(callNode_, args) case i: IASTInitializerList => val operatorName = Operators.assignment val callNode_ = - callNode(declarator, code(declarator), operatorName, operatorName, DispatchTypes.STATIC_DISPATCH) + callNode( + declarator, + code(declarator), + operatorName, + operatorName, + DispatchTypes.STATIC_DISPATCH, + None, + Some(X2CpgDefines.Any) + ) val left = astForNode(declarator.getName) val right = astForNode(i) callAst(callNode_, List(left, right)) @@ -151,7 +190,7 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: protected def astForASMDeclaration(asm: IASTASMDeclaration): Ast = Ast(unknownNode(asm, code(asm))) private def astForStructuredBindingDeclaration(decl: ICPPASTStructuredBindingDeclaration): Ast = { - val node = blockNode(decl, Defines.empty, Defines.voidTypeName) + val node = blockNode(decl, Defines.Empty, Defines.Void) scope.pushNewScope(node) val childAsts = decl.getNames.toList.map { name => astForNode(name) @@ -192,7 +231,7 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: case _ if declaration.getDeclarators.isEmpty => Seq(astForNode(declaration)) } case alias: CPPASTAliasDeclaration => Seq(astForAliasDeclaration(alias)) - case functDef: IASTFunctionDefinition => Seq(astForFunctionDefinition(functDef)) + case functionDefinition: IASTFunctionDefinition => Seq(astForFunctionDefinition(functionDefinition)) case namespaceAlias: ICPPASTNamespaceAlias => Seq(astForNamespaceAlias(namespaceAlias)) case namespaceDefinition: ICPPASTNamespaceDefinition => Seq(astForNamespaceDefinition(namespaceDefinition)) case a: ICPPASTStaticAssertDeclaration => Seq(astForStaticAssert(a)) @@ -212,8 +251,9 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: case d: IASTDeclarator if d.getInitializer != null => astForInitializer(d, d.getInitializer) case arrayDecl: IASTArrayDeclarator => - val op = Operators.arrayInitializer - val initCallNode = callNode(arrayDecl, code(arrayDecl), op, op, DispatchTypes.STATIC_DISPATCH) + val op = Operators.arrayInitializer + val initCallNode = + callNode(arrayDecl, code(arrayDecl), op, op, DispatchTypes.STATIC_DISPATCH, None, Some(X2CpgDefines.Any)) val initArgs = arrayDecl.getArrayModifiers.toList.filter(m => m.getConstantExpression != null).map(astForNode) callAst(initCallNode, initArgs) @@ -229,31 +269,35 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: astsForDeclaration(d) } + private def filterNameAlias( + nameAlias: Option[String], + nameWithTemplateParams: Option[String], + fullName: String + ): Option[String] = { + (nameAlias.toList ++ nameWithTemplateParams.toList).filter(n => n != fullName).distinct.headOption + } + private def astsForCompositeType(typeSpecifier: IASTCompositeTypeSpecifier, decls: List[IASTDeclarator]): Seq[Ast] = { val filename = fileName(typeSpecifier) val declAsts = decls.zipWithIndex.map { case (d, i) => astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) } - val lineNumber = line(typeSpecifier) - val columnNumber = column(typeSpecifier) - val fullname = registerType(cleanType(fullName(typeSpecifier))) - val name = ASTStringUtil.getSimpleName(typeSpecifier.getName) match { - case n if n.isEmpty => lastNameOfQualifiedName(fullname) - case other => other - } - val codeString = code(typeSpecifier) - val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) - val nameWithTemplateParams = templateParameters(typeSpecifier).map(t => registerType(s"$fullname$t")) - val alias = (nameAlias.toList ++ nameWithTemplateParams.toList).headOption + val lineNumber = line(typeSpecifier) + val columnNumber = column(typeSpecifier) + val TypeFullNameInfo(name, fullName) = typeFullNameInfo(typeSpecifier) + val codeString = code(typeSpecifier) + val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) + val nameWithTemplateParams = templateParameters(typeSpecifier).map(t => registerType(s"$fullName$t")) + val alias = filterNameAlias(nameAlias, nameWithTemplateParams, fullName) val typeDecl = typeSpecifier match { case cppClass: ICPPASTCompositeTypeSpecifier => val baseClassList = cppClass.getBaseSpecifiers.toSeq.map(s => registerType(s.getNameSpecifier.toString)) - typeDeclNode(typeSpecifier, name, fullname, filename, codeString, inherits = baseClassList, alias = alias) + typeDeclNode(typeSpecifier, name, fullName, filename, codeString, inherits = baseClassList, alias = alias) case _ => - typeDeclNode(typeSpecifier, name, fullname, filename, codeString, alias = alias) + typeDeclNode(typeSpecifier, name, fullName, filename, codeString, alias = alias) } methodAstParentStack.push(typeDecl) @@ -270,9 +314,9 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: } else { val init = staticInitMethodAst( calls, - s"$fullname:${io.joern.x2cpg.Defines.StaticInitMethodName}", + s"$fullName.${io.joern.x2cpg.Defines.StaticInitMethodName}", None, - Defines.anyTypeName, + Defines.Any, Some(filename), lineNumber, columnNumber @@ -289,16 +333,11 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: val declAsts = decls.zipWithIndex.map { case (d, i) => astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) } - - val name = ASTStringUtil.getSimpleName(typeSpecifier.getName) - val fullname = registerType(cleanType(fullName(typeSpecifier))) - val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) - val nameWithTemplateParams = templateParameters(typeSpecifier).map(t => registerType(s"$fullname$t")) - val alias = (nameAlias.toList ++ nameWithTemplateParams.toList).headOption - - val typeDecl = - typeDeclNode(typeSpecifier, name, fullname, filename, code(typeSpecifier), alias = alias) - + val TypeFullNameInfo(name, fullName) = typeFullNameInfo(typeSpecifier) + val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) + val nameWithTemplateParams = templateParameters(typeSpecifier).map(t => registerType(s"$fullName$t")) + val alias = filterNameAlias(nameAlias, nameWithTemplateParams, fullName) + val typeDecl = typeDeclNode(typeSpecifier, name, fullName, filename, code(typeSpecifier), alias = alias) Ast(typeDecl) +: declAsts } @@ -318,7 +357,15 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: if (enumerator.getValue != null) { val operatorName = Operators.assignment val callNode_ = - callNode(enumerator, code(enumerator), operatorName, operatorName, DispatchTypes.STATIC_DISPATCH) + callNode( + enumerator, + code(enumerator), + operatorName, + operatorName, + DispatchTypes.STATIC_DISPATCH, + None, + Some(X2CpgDefines.Any) + ) val left = astForNode(enumerator.getName) val right = astForNode(enumerator.getValue) val ast = callAst(callNode_, List(left, right)) @@ -334,15 +381,15 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) } - val lineNumber = line(typeSpecifier) - val columnNumber = column(typeSpecifier) - val (name, fullname) = - uniqueName("enum", ASTStringUtil.getSimpleName(typeSpecifier.getName), fullName(typeSpecifier)) - val alias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) + val lineNumber = line(typeSpecifier) + val columnNumber = column(typeSpecifier) + val TypeFullNameInfo(name, fullName) = typeFullNameInfo(typeSpecifier) + val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) + val alias = filterNameAlias(nameAlias, None, fullName) val (deAliasedName, deAliasedFullName, newAlias) = if (name.contains("anonymous_enum") && alias.isDefined) { - (alias.get, fullname.substring(0, fullname.indexOf("anonymous_enum")) + alias.get, None) - } else { (name, fullname, alias) } + (alias.get, fullName.substring(0, fullName.indexOf("anonymous_enum")) + alias.get, None) + } else { (name, fullName, alias) } val typeDecl = typeDeclNode( @@ -370,7 +417,7 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: calls, s"$deAliasedFullName:${io.joern.x2cpg.Defines.StaticInitMethodName}", None, - Defines.anyTypeName, + Defines.Any, Some(filename), lineNumber, columnNumber diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstNodeBuilder.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstNodeBuilder.scala index f3d7316835f2..5499f2d655a4 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstNodeBuilder.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/AstNodeBuilder.scala @@ -1,12 +1,12 @@ package io.joern.c2cpg.astcreation -import io.joern.x2cpg.utils.NodeBuilders.{newMethodReturnNode => newMethodReturnNode_} -import io.shiftleft.codepropertygraph.generated.nodes._ -import org.eclipse.cdt.core.dom.ast.{IASTLabelStatement, IASTNode} -import org.eclipse.cdt.core.dom.ast.IASTPreprocessorIncludeStatement +import io.shiftleft.codepropertygraph.generated.nodes.* +import org.eclipse.cdt.core.dom.ast.IASTLabelStatement +import org.eclipse.cdt.core.dom.ast.IASTNode import org.eclipse.cdt.internal.core.model.ASTStringUtil trait AstNodeBuilder { this: AstCreator => + protected def newCommentNode(node: IASTNode, code: String, filename: String): NewComment = { NewComment().code(code).filename(filename).lineNumber(line(node)).columnNumber(column(node)) } @@ -14,7 +14,7 @@ trait AstNodeBuilder { this: AstCreator => protected def newNamespaceBlockNode( node: IASTNode, name: String, - fullname: String, + fullName: String, code: String, filename: String ): NewNamespaceBlock = { @@ -24,12 +24,7 @@ trait AstNodeBuilder { this: AstCreator => .columnNumber(column(node)) .filename(filename) .name(name) - .fullName(fullname) - } - - // TODO: We should get rid of this method as its being used at multiple places and use it from x2cpg/AstNodeBuilder "methodReturnNode" - protected def newMethodReturnNode(node: IASTNode, typeFullName: String): NewMethodReturn = { - newMethodReturnNode_(typeFullName, None, line(node), column(node)) + .fullName(fullName) } protected def newJumpTargetNode(node: IASTNode): NewJumpTarget = { diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/C2CpgScope.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/C2CpgScope.scala new file mode 100644 index 000000000000..ae8434cd3164 --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/C2CpgScope.scala @@ -0,0 +1,40 @@ +package io.joern.c2cpg.astcreation + +import io.joern.x2cpg.datastructures.Scope +import io.shiftleft.codepropertygraph.generated.nodes.NewLocal +import io.shiftleft.codepropertygraph.generated.nodes.NewMethodParameterIn +import io.shiftleft.codepropertygraph.generated.nodes.NewNode + +object C2CpgScope { + private type NewVariableNode = NewLocal | NewMethodParameterIn + + sealed trait ScopeVariable { + def node: NewVariableNode + def typeFullName: String + def name: String + } + + private final case class ScopeLocal(override val node: NewLocal) extends ScopeVariable { + val typeFullName: String = node.typeFullName + val name: String = node.name + } + + private final case class ScopeParameter(override val node: NewMethodParameterIn) extends ScopeVariable { + val typeFullName: String = node.typeFullName + val name: String = node.name + } + +} + +class C2CpgScope extends Scope[String, (NewNode, String), NewNode] { + + import C2CpgScope.* + + def variablesInScope: List[ScopeVariable] = { + stack.reverse.flatMap(_.variables.values.map(_._1)).collect { + case local: NewLocal => ScopeLocal(local) + case parameter: NewMethodParameterIn => ScopeParameter(parameter) + } + } + +} diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/CGlobal.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/CGlobal.scala new file mode 100644 index 000000000000..bb417bd27a9a --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/CGlobal.scala @@ -0,0 +1,43 @@ +package io.joern.c2cpg.astcreation + +import io.joern.x2cpg.datastructures.Global +import java.util.concurrent.ConcurrentHashMap + +object CGlobal { + + final case class MethodInfo( + name: String, + code: String, + fileName: String, + returnType: String, + astParentType: String, + astParentFullName: String, + lineNumber: Option[Int], + columnNumber: Option[Int], + lineNumberEnd: Option[Int], + columnNumberEnd: Option[Int], + signature: String, + offset: Option[(Int, Int)], + parameter: Seq[ParameterInfo], + modifier: Seq[String] + ) + final class ParameterInfo( + val name: String, + var code: String, + val index: Int, + var isVariadic: Boolean, + val evaluationStrategy: String, + val lineNumber: Option[Int], + val columnNumber: Option[Int], + val typeFullName: String + ) + +} + +class CGlobal extends Global { + import io.joern.c2cpg.astcreation.CGlobal.MethodInfo + + val methodDeclarations: ConcurrentHashMap[String, MethodInfo] = new ConcurrentHashMap() + val methodDefinitions: ConcurrentHashMap[String, Boolean] = new ConcurrentHashMap() + +} diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/Defines.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/Defines.scala index f697eb70ca23..c5f8d68554c7 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/Defines.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/Defines.scala @@ -1,10 +1,23 @@ package io.joern.c2cpg.astcreation object Defines { - val anyTypeName: String = "ANY" - val voidTypeName: String = "void" - val qualifiedNameSeparator: String = "::" - val empty = "" + val Any: String = "ANY" + val Void: String = "void" + val Function: String = "std.function" + val Array: String = "std.array" + val QualifiedNameSeparator: String = "::" + val Empty = "" + val Auto = "auto" - val operatorPointerCall = ".pointerCall" + val OperatorPointerCall = ".pointerCall" + val OperatorConstructorInitializer = ".constructorInitializer" + val OperatorTypeOf = ".typeOf" + val OperatorMax = ".max" + val OperatorMin = ".min" + val OperatorEllipses = ".op_ellipses" + val OperatorUnknown = ".unknown" + val OperatorCall = "()" + val OperatorExpressionList = ".expressionList" + val OperatorNew = ".new" + val OperatorBracketedPrimary = ".bracketedPrimary" } diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/FullNameProvider.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/FullNameProvider.scala new file mode 100644 index 000000000000..6db8f39fb737 --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/FullNameProvider.scala @@ -0,0 +1,446 @@ +package io.joern.c2cpg.astcreation + +import org.apache.commons.lang3.StringUtils +import org.eclipse.cdt.core.dom.ast.* +import org.eclipse.cdt.core.dom.ast.cpp.* +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTIdExpression +import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.EvalBinding +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTFunctionDeclarator +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTFunctionDefinition +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPFunction +import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPVariable +import org.eclipse.cdt.internal.core.model.ASTStringUtil +import io.joern.x2cpg.Defines as X2CpgDefines +import io.joern.x2cpg.passes.frontend.MetaDataPass +import io.shiftleft.codepropertygraph.generated.nodes.NewMethod +import io.shiftleft.codepropertygraph.generated.nodes.NewTypeDecl +import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal +import org.eclipse.cdt.internal.core.dom.parser.c.CASTFunctionDeclarator +import org.eclipse.cdt.internal.core.dom.parser.c.CVariable + +import scala.util.Try + +trait FullNameProvider { this: AstCreator => + + protected type MethodLike = IASTFunctionDeclarator | IASTFunctionDefinition | ICPPASTLambdaExpression + + protected type TypeLike = IASTEnumerationSpecifier | ICPPASTNamespaceDefinition | ICPPASTNamespaceAlias | + IASTCompositeTypeSpecifier | IASTElaboratedTypeSpecifier + + protected def fixQualifiedName(name: String): String = { + if (name.isEmpty) { name } + else { + val normalizedName = StringUtils.normalizeSpace(name) + normalizedName + .stripPrefix(Defines.QualifiedNameSeparator) + .replace(Defines.QualifiedNameSeparator, ".") + .stripPrefix(".") + } + } + + protected def isQualifiedName(name: String): Boolean = + name.startsWith(Defines.QualifiedNameSeparator) + + protected def lastNameOfQualifiedName(name: String): String = { + val normalizedName = StringUtils.normalizeSpace(replaceOperator(name)) + val cleanedName = if (normalizedName.contains("<") && normalizedName.contains(">")) { + name.substring(0, normalizedName.indexOf("<")) + } else { + normalizedName + } + cleanedName.split(Defines.QualifiedNameSeparator).lastOption.getOrElse(cleanedName) + } + + protected def methodFullNameInfo(methodLike: MethodLike): MethodFullNameInfo = { + val returnType_ = returnType(methodLike) + val signature_ = signature(returnType_, methodLike) + val name_ = shortName(methodLike) + val fullName_ = fullName(methodLike) + val sanitizedFullName = sanitizeMethodLikeFullName(name_, fullName_, signature_, methodLike) + MethodFullNameInfo(name_, sanitizedFullName, signature_, returnType_) + } + + protected def typeFullNameInfo(typeLike: TypeLike): TypeFullNameInfo = { + typeLike match { + case _: IASTElaboratedTypeSpecifier => + val name_ = shortName(typeLike) + val fullName_ = registerType(cleanType(fullName(typeLike))) + TypeFullNameInfo(name_, fullName_) + case e: IASTEnumerationSpecifier => + val name_ = shortName(e) + val fullName_ = fullName(e) + val (uniqueName_, uniqueNameFullName_) = uniqueName("enum", name_, fullName_) + TypeFullNameInfo(uniqueName_, uniqueNameFullName_) + case n: ICPPASTNamespaceDefinition => + val name_ = shortName(n) + val fullName_ = fullName(n) + val (uniqueName_, uniqueNameFullName_) = uniqueName("namespace", name_, fullName_) + TypeFullNameInfo(uniqueName_, uniqueNameFullName_) + case a: ICPPASTNamespaceAlias => + val name_ = shortName(a) + val fullName_ = fullName(a) + TypeFullNameInfo(name_, fullName_) + case s: IASTCompositeTypeSpecifier => + val fullName_ = registerType(cleanType(fullName(s))) + val name_ = shortName(s) match { + case n if n.isEmpty => lastNameOfQualifiedName(fullName_) + case other => other + } + TypeFullNameInfo(name_, fullName_) + } + } + + protected def shortName(node: IASTNode): String = { + val name = node match { + case s: IASTSimpleDeclSpecifier => s.getRawSignature + case d: IASTDeclarator => shortNameForIASTDeclarator(d) + case f: ICPPASTFunctionDefinition => shortNameForICPPASTFunctionDefinition(f) + case f: IASTFunctionDefinition => shortNameForIASTFunctionDefinition(f) + case u: IASTUnaryExpression => shortName(u.getOperand) + case c: IASTFunctionCallExpression => shortName(c.getFunctionNameExpression) + case d: CPPASTIdExpression => shortNameForCPPASTIdExpression(d) + case d: IASTIdExpression => shortNameForIASTIdExpression(d) + case a: ICPPASTNamespaceAlias => ASTStringUtil.getSimpleName(a.getAlias) + case n: ICPPASTNamespaceDefinition => ASTStringUtil.getSimpleName(n.getName) + case e: IASTEnumerationSpecifier => ASTStringUtil.getSimpleName(e.getName) + case c: IASTCompositeTypeSpecifier => ASTStringUtil.getSimpleName(c.getName) + case e: IASTElaboratedTypeSpecifier => ASTStringUtil.getSimpleName(e.getName) + case s: IASTNamedTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) + case _: ICPPASTLambdaExpression => nextClosureName() + case other => + notHandledYet(other) + nextClosureName() + } + StringUtils.normalizeSpace(name) + } + + private def fullNameForICPPASTLambdaExpression(): String = { + methodAstParentStack + .collectFirst { + case t: NewTypeDecl => + if (t.name != NamespaceTraversal.globalNamespaceName) { + val globalFullName = MetaDataPass.getGlobalNamespaceBlockFullName(Some(filename)) + s"$globalFullName.${t.fullName}" + } else { + t.fullName + } + case m: NewMethod => + val fullNameWithoutSignature = m.fullName.stripSuffix(s":${m.signature}") + if (!m.name.startsWith("") && m.name != NamespaceTraversal.globalNamespaceName) { + val globalFullName = MetaDataPass.getGlobalNamespaceBlockFullName(Some(filename)) + s"$globalFullName.$fullNameWithoutSignature" + } else { + fullNameWithoutSignature + } + } + .mkString("") + } + + protected def fullName(node: IASTNode): String = { + fullNameFromBinding(node) match { + case Some(fullName) => + StringUtils.normalizeSpace(fullName) + case None => + val qualifiedName = node match { + case _: IASTTranslationUnit => "" + case alias: ICPPASTNamespaceAlias => fixQualifiedName(ASTStringUtil.getQualifiedName(alias.getMappingName)) + case namespace: ICPPASTNamespaceDefinition => fullNameForICPPASTNamespaceDefinition(namespace) + case compType: IASTCompositeTypeSpecifier => fullNameForIASTCompositeTypeSpecifier(compType) + case enumSpecifier: IASTEnumerationSpecifier => fullNameForIASTEnumerationSpecifier(enumSpecifier) + case f: IASTFunctionDeclarator => fullNameForIASTFunctionDeclarator(f) + case f: IASTFunctionDefinition => fullNameForIASTFunctionDefinition(f) + case e: IASTElaboratedTypeSpecifier => fullNameForIASTElaboratedTypeSpecifier(e) + case d: IASTIdExpression => ASTStringUtil.getSimpleName(d.getName) + case u: IASTUnaryExpression => code(u.getOperand) + case x: ICPPASTQualifiedName => fixQualifiedName(ASTStringUtil.getQualifiedName(x)) + case _: ICPPASTLambdaExpression => fullNameForICPPASTLambdaExpression() + case other if other != null && other.getParent != null => fullName(other.getParent) + case other if other != null => notHandledYet(other); "" + case null => "" + } + fixQualifiedName(qualifiedName).stripPrefix(".") + } + } + + private def isCPPFunction(methodLike: MethodLike): Boolean = { + methodLike.isInstanceOf[CPPASTFunctionDeclarator] || methodLike.isInstanceOf[CPPASTFunctionDefinition] + } + + private def sanitizeMethodLikeFullName( + name: String, + fullName: String, + signature: String, + methodLike: MethodLike + ): String = { + fullName match { + case f if methodLike.isInstanceOf[ICPPASTLambdaExpression] && (f.contains("[") || f.contains("{")) => + s"${X2CpgDefines.UnresolvedNamespace}.$name:$signature" + case f if methodLike.isInstanceOf[ICPPASTLambdaExpression] && f.isEmpty => + s"$name:$signature" + case f if methodLike.isInstanceOf[ICPPASTLambdaExpression] => + s"$f.$name:$signature" + case f if isCPPFunction(methodLike) && (f.isEmpty || f == s"${X2CpgDefines.UnresolvedNamespace}.") => + s"${X2CpgDefines.UnresolvedNamespace}.$name:$signature" + case f if isCPPFunction(methodLike) && f.contains("?") => + s"${StringUtils.normalizeSpace(f).takeWhile(_ != ':')}:$signature" + case f if f.isEmpty || f == s"${X2CpgDefines.UnresolvedNamespace}." => + s"${X2CpgDefines.UnresolvedNamespace}.$name" + case other if other.nonEmpty => other + case _ => s"${X2CpgDefines.UnresolvedNamespace}.$name" + } + } + + private def returnTypeForIASTFunctionDeclarator(declarator: IASTFunctionDeclarator): String = { + safeGetBinding(declarator.getName) match { + case Some(value: ICPPMethod) => + cleanType(value.getType.getReturnType.toString) + case _ if declarator.getParent.isInstanceOf[IASTSimpleDeclaration] => + cleanType(typeForDeclSpecifier(declarator.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclSpecifier)) + case _ if declarator.getParent.isInstanceOf[IASTFunctionDefinition] => + cleanType(typeForDeclSpecifier(declarator.getParent.asInstanceOf[IASTFunctionDefinition].getDeclSpecifier)) + case _ => Defines.Any + } + } + + private def returnTypeForIASTFunctionDefinition(definition: IASTFunctionDefinition): String = { + if (isCppConstructor(definition)) { + typeFor(definition.asInstanceOf[CPPASTFunctionDefinition].getMemberInitializers.head.getInitializer) + } else { + safeGetBinding(definition.getDeclarator.getName) match { + case Some(value: ICPPMethod) => + cleanType(value.getType.getReturnType.toString) + case _ => + typeForDeclSpecifier(definition.getDeclSpecifier) + } + } + } + + private def returnTypeForICPPASTLambdaExpression(lambda: ICPPASTLambdaExpression): String = { + lambda.getDeclarator match { + case declarator: IASTDeclarator => + Option(declarator.getTrailingReturnType) + .map(id => typeForDeclSpecifier(id.getDeclSpecifier)) + .getOrElse(Defines.Any) + case null => + safeGetEvaluation(lambda) match { + case Some(value) if !value.toString.endsWith(": ") => cleanType(value.getType.toString) + case _ => Defines.Any + } + } + } + + private def returnType(methodLike: MethodLike): String = { + methodLike match { + case declarator: IASTFunctionDeclarator => returnTypeForIASTFunctionDeclarator(declarator) + case definition: IASTFunctionDefinition => returnTypeForIASTFunctionDefinition(definition) + case lambda: ICPPASTLambdaExpression => returnTypeForICPPASTLambdaExpression(lambda) + } + } + + private def parameterListSignature(func: IASTNode): String = { + val variadic = if (isVariadic(func)) "..." else "" + val elements = parameters(func).map { + case p: IASTParameterDeclaration => typeForDeclSpecifier(p.getDeclSpecifier) + case other => typeForDeclSpecifier(other) + } + s"(${elements.mkString(",")}$variadic)" + } + + private def signature(returnType: String, methodLike: MethodLike): String = { + StringUtils.normalizeSpace(s"$returnType${parameterListSignature(methodLike)}") + } + + private def shortNameForIASTDeclarator(declarator: IASTDeclarator): String = { + safeGetBinding(declarator.getName).map(_.getName).getOrElse { + if (ASTStringUtil.getSimpleName(declarator.getName).isEmpty && declarator.getNestedDeclarator != null) { + shortName(declarator.getNestedDeclarator) + } else { + ASTStringUtil.getSimpleName(declarator.getName) + } + } + } + + private def shortNameForICPPASTFunctionDefinition(definition: ICPPASTFunctionDefinition): String = { + if ( + ASTStringUtil.getSimpleName(definition.getDeclarator.getName).isEmpty + && definition.getDeclarator.getNestedDeclarator != null + ) { + shortName(definition.getDeclarator.getNestedDeclarator) + } else { + lastNameOfQualifiedName(ASTStringUtil.getSimpleName(definition.getDeclarator.getName)) + } + } + + private def shortNameForIASTFunctionDefinition(definition: IASTFunctionDefinition): String = { + if ( + ASTStringUtil.getSimpleName(definition.getDeclarator.getName).isEmpty + && definition.getDeclarator.getNestedDeclarator != null + ) { + shortName(definition.getDeclarator.getNestedDeclarator) + } else { + ASTStringUtil.getSimpleName(definition.getDeclarator.getName) + } + } + + private def shortNameForCPPASTIdExpression(d: CPPASTIdExpression): String = { + val name = safeGetEvaluation(d) match { + case Some(evalBinding: EvalBinding) => + evalBinding.getBinding match { + case f: CPPFunction if f.getDeclarations != null => + f.getDeclarations.headOption.map(n => ASTStringUtil.getSimpleName(n.getName)).getOrElse(f.getName) + case f: CPPFunction if f.getDefinition != null => ASTStringUtil.getSimpleName(f.getDefinition.getName) + case other => other.getName + } + case _ => ASTStringUtil.getSimpleName(d.getName) + } + lastNameOfQualifiedName(name) + } + + private def shortNameForIASTIdExpression(d: IASTIdExpression): String = { + lastNameOfQualifiedName(ASTStringUtil.getSimpleName(d.getName)) + } + + private def replaceOperator(name: String): String = { + name + .replace("operator class ", "") + .replace("operator enum ", "") + .replace("operator struct ", "") + .replace("operator ", "") + } + + private def fullNameFromBinding(node: IASTNode): Option[String] = { + node match { + case id: CPPASTIdExpression => + safeGetEvaluation(id) match { + case Some(evalBinding: EvalBinding) => + evalBinding.getBinding match { + case f: CPPFunction if f.getDeclarations != null => + Option(f.getDeclarations.headOption.map(n => s"${fullName(n)}").getOrElse(f.getName)) + case f: CPPFunction if f.getDefinition != null => + Option(s"${fullName(f.getDefinition)}") + case other => + Option(other.getName) + } + case _ => None + } + case declarator: CPPASTFunctionDeclarator => + safeGetBinding(declarator.getName) match { + case Some(function: ICPPFunction) if declarator.getName.isInstanceOf[ICPPASTConversionName] => + val tpe = cleanType(typeFor(declarator.getName.asInstanceOf[ICPPASTConversionName].getTypeId)) + val fullNameNoSig = fixQualifiedName( + function.getQualifiedName.takeWhile(!_.startsWith("operator ")).mkString(".") + ) + val fn = if (function.isExternC) { + tpe + } else { + s"$fullNameNoSig.$tpe:${functionTypeToSignature(function.getType)}" + } + Option(fn) + case Some(function: ICPPFunction) => + val fullNameNoSig = fixQualifiedName(replaceOperator(function.getQualifiedName.mkString("."))) + val fn = if (function.isExternC) { + replaceOperator(function.getName) + } else { + val returnTpe = declarator.getParent match { + case definition: ICPPASTFunctionDefinition if !isCppConstructor(definition) => returnType(definition) + case _ => safeGetType(function.getType.getReturnType) + } + val sig = signature(cleanType(returnTpe), declarator) + s"$fullNameNoSig:$sig" + } + Option(fn) + case Some(x @ (_: ICPPField | _: CPPVariable)) => + val fullNameNoSig = fixQualifiedName(x.getQualifiedName.mkString(".")) + val fn = if (x.isExternC) { + x.getName + } else { + s"$fullNameNoSig:${cleanType(safeGetType(x.getType))}" + } + Option(fn) + case Some(_: IProblemBinding) => + val fullNameNoSig = replaceOperator(ASTStringUtil.getQualifiedName(declarator.getName)) + val fixedFullName = fixQualifiedName(fullNameNoSig) + val returnTpe = declarator.getParent match { + case definition: ICPPASTFunctionDefinition if !isCppConstructor(definition) => returnType(definition) + case _ => returnType(declarator) + } + val signature_ = signature(returnTpe, declarator) + if (fixedFullName.isEmpty) { + Option(s"${X2CpgDefines.UnresolvedNamespace}:$signature_") + } else { + Option(s"$fixedFullName:$signature_") + } + case _ => None + } + case declarator: CASTFunctionDeclarator => + safeGetBinding(declarator.getName) match { + case Some(cVariable: CVariable) => Option(cVariable.getName) + case Some(cppVariable: CPPVariable) => Option(cppVariable.getName) + case _ => Option(declarator.getName.toString) + } + case definition: ICPPASTFunctionDefinition => + Some(fullName(definition.getDeclarator)) + case namespace: ICPPASTNamespaceDefinition => + safeGetBinding(namespace.getName) match { + case Some(b: ICPPBinding) if b.getName.nonEmpty => Option(b.getQualifiedName.mkString(".")) + case _ => None + } + case compType: IASTCompositeTypeSpecifier => + safeGetBinding(compType.getName) match { + case Some(b: ICPPBinding) if b.getName.nonEmpty => Option(b.getQualifiedName.mkString(".")) + case _ => None + } + case enumSpecifier: IASTEnumerationSpecifier => + safeGetBinding(enumSpecifier.getName) match { + case Some(b: ICPPBinding) if b.getName.nonEmpty => Option(b.getQualifiedName.mkString(".")) + case _ => None + } + case e: IASTElaboratedTypeSpecifier => + safeGetBinding(e.getName) match { + case Some(b: ICPPBinding) if b.getName.nonEmpty => Option(b.getQualifiedName.mkString(".")) + case _ => None + } + case _ => None + } + } + + private def fullNameForICPPASTNamespaceDefinition(namespace: ICPPASTNamespaceDefinition): String = { + s"${fullName(namespace.getParent)}.${ASTStringUtil.getSimpleName(namespace.getName)}" + } + + private def fullNameForIASTCompositeTypeSpecifier(compType: IASTCompositeTypeSpecifier): String = { + if (ASTStringUtil.getSimpleName(compType.getName).nonEmpty) { + s"${fullName(compType.getParent)}.${ASTStringUtil.getSimpleName(compType.getName)}" + } else { + val name = compType.getParent match { + case decl: IASTSimpleDeclaration => + decl.getDeclarators.headOption + .map(n => ASTStringUtil.getSimpleName(n.getName)) + .getOrElse(uniqueName("composite_type", "", "")._1) + case _ => uniqueName("composite_type", "", "")._1 + } + s"${fullName(compType.getParent)}.$name" + } + } + + private def fullNameForIASTEnumerationSpecifier(enumSpecifier: IASTEnumerationSpecifier): String = { + s"${fullName(enumSpecifier.getParent)}.${ASTStringUtil.getSimpleName(enumSpecifier.getName)}" + } + + private def fullNameForIASTElaboratedTypeSpecifier(e: IASTElaboratedTypeSpecifier): String = { + s"${fullName(e.getParent)}.${ASTStringUtil.getSimpleName(e.getName)}" + } + + private def fullNameForIASTFunctionDeclarator(f: IASTFunctionDeclarator): String = { + Try(fixQualifiedName(ASTStringUtil.getQualifiedName(f.getName))).getOrElse(nextClosureName()) + } + + private def fullNameForIASTFunctionDefinition(f: IASTFunctionDefinition): String = { + Try(fixQualifiedName(ASTStringUtil.getQualifiedName(f.getDeclarator.getName))).getOrElse(nextClosureName()) + } + + protected final case class MethodFullNameInfo(name: String, fullName: String, signature: String, returnType: String) + + protected final case class TypeFullNameInfo(name: String, fullName: String) + +} diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/MacroArgumentExtractor.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/MacroArgumentExtractor.scala index c6ae64535fea..76e5aa4ba5c3 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/MacroArgumentExtractor.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/MacroArgumentExtractor.scala @@ -26,8 +26,6 @@ import scala.collection.mutable * `setExpandedMacroArgument`, we can intercept arguments and store them in a list for later retrieval. We wrap this * rather complicated way of accessing the macro arguments in the single public method `getArguments` of the * `MacroArgumentExtractor`. - * - * This class must be in this package in order to have access to `PreprocessorMacro`. */ class MacroArgumentExtractor(tu: IASTTranslationUnit, loc: IASTFileLocation) { diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/MacroHandler.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/MacroHandler.scala index 9001ddcd0855..f041eff3b82a 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/MacroHandler.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/astcreation/MacroHandler.scala @@ -1,19 +1,20 @@ package io.joern.c2cpg.astcreation +import io.joern.x2cpg.Ast +import io.joern.x2cpg.ValidationMode import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.codepropertygraph.generated.nodes.{ - AstNodeNew, - ExpressionNew, - NewBlock, - NewCall, - NewFieldIdentifier, - NewNode -} -import io.joern.x2cpg.{Ast, AstEdge, ValidationMode} +import io.shiftleft.codepropertygraph.generated.nodes.AstNodeNew +import io.shiftleft.codepropertygraph.generated.nodes.ExpressionNew +import io.shiftleft.codepropertygraph.generated.nodes.NewBlock +import io.shiftleft.codepropertygraph.generated.nodes.NewCall +import io.shiftleft.codepropertygraph.generated.nodes.NewFieldIdentifier import io.shiftleft.codepropertygraph.generated.nodes.NewLocal +import io.shiftleft.codepropertygraph.generated.nodes.NewNode import org.apache.commons.lang3.StringUtils -import org.eclipse.cdt.core.dom.ast.{IASTMacroExpansionLocation, IASTNode, IASTPreprocessorMacroDefinition} import org.eclipse.cdt.core.dom.ast.IASTBinaryExpression +import org.eclipse.cdt.core.dom.ast.IASTMacroExpansionLocation +import org.eclipse.cdt.core.dom.ast.IASTNode +import org.eclipse.cdt.core.dom.ast.IASTPreprocessorMacroDefinition import org.eclipse.cdt.internal.core.model.ASTStringUtil import scala.annotation.nowarn @@ -45,18 +46,16 @@ trait MacroHandler(implicit withSchemaValidation: ValidationMode) { this: AstCre val macroCallAst = matchingMacro.map { case (mac, args) => createMacroCallAst(ast, node, mac, args) } macroCallAst match { case Some(callAst) => - val lostLocals = ast.refEdges.collect { case AstEdge(_, dst: NewLocal) => Ast(dst) }.toList - val newAst = ast.subTreeCopy(ast.root.get.asInstanceOf[AstNodeNew], argIndex = 1) - // We need to wrap the copied AST as it may contain CPG nodes not being allowed - // to be connected via AST edges under a CALL. E.g., LOCALs but only if its not already a BLOCK. - val childAst = newAst.root match { - case Some(_: NewBlock) => - newAst + // We need to wrap the AST as it may contain CPG nodes not being allowed + // to be connected via AST edges under a CALL. E.g., LOCALs but only if it is not already a BLOCK. + val childAst = ast.root match { + case Some(_: NewBlock) => ast case _ => - val b = NewBlock().argumentIndex(1).typeFullName(registerType(Defines.voidTypeName)) - blockAst(b, List(newAst)) + setArgumentIndices(List(ast)) + blockAst(blockNode(node), List(ast)) } - callAst.withChildren(lostLocals).withChild(childAst) + setArgumentIndices(List(childAst)) + callAst.withChild(childAst) case None => ast } } @@ -124,13 +123,14 @@ trait MacroHandler(implicit withSchemaValidation: ValidationMode) { this: AstCre val callName = StringUtils.normalizeSpace(name) val callFullName = StringUtils.normalizeSpace(fullName(macroDef, argAsts)) + val typeFullName = registerType(cleanType(typeFor(node))) val callNode = NewCall() .name(callName) .dispatchType(DispatchTypes.INLINED) .methodFullName(callFullName) .code(code) - .typeFullName(typeFor(node)) + .typeFullName(typeFullName) .lineNumber(line(node)) .columnNumber(column(node)) callAst(callNode, argAsts) diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/CdtParser.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/CdtParser.scala index 8c4f051070dd..de443d7ebe35 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/CdtParser.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/CdtParser.scala @@ -2,18 +2,22 @@ package io.joern.c2cpg.parser import better.files.File import io.joern.c2cpg.Config +import io.joern.c2cpg.parser.JSONCompilationDatabaseParser.CommandObject import io.shiftleft.utils.IOUtils +import org.eclipse.cdt.core.dom.ast.IASTPreprocessorStatement +import org.eclipse.cdt.core.dom.ast.IASTTranslationUnit import org.eclipse.cdt.core.dom.ast.gnu.c.GCCLanguage import org.eclipse.cdt.core.dom.ast.gnu.cpp.GPPLanguage -import org.eclipse.cdt.core.dom.ast.{IASTPreprocessorStatement, IASTTranslationUnit} import org.eclipse.cdt.core.model.ILanguage -import org.eclipse.cdt.core.parser.{DefaultLogService, ScannerInfo} +import org.eclipse.cdt.core.parser.DefaultLogService import org.eclipse.cdt.core.parser.FileContent +import org.eclipse.cdt.core.parser.ScannerInfo import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.CPPVisitor -import org.eclipse.cdt.internal.core.parser.scanner.InternalFileContent import org.slf4j.LoggerFactory -import java.nio.file.{NoSuchFileException, Path} +import java.nio.file.NoSuchFileException +import java.nio.file.Path +import scala.collection.mutable import scala.jdk.CollectionConverters.* import scala.util.Failure import scala.util.Success @@ -30,22 +34,26 @@ object CdtParser { failure: Option[Throwable] = None ) - def readFileAsFileContent(path: Path): FileContent = { - val lines = IOUtils.readLinesInFile(path).mkString("\n").toArray - FileContent.create(path.toString, true, lines) + private def readFileAsFileContent(file: File, lines: Option[Array[Char]] = None): FileContent = { + val codeLines = lines.getOrElse(IOUtils.readLinesInFile(file.path).mkString("\n").toArray) + FileContent.create(file.pathAsString, true, codeLines) } } -class CdtParser(config: Config) extends ParseProblemsLogger with PreprocessorStatementsLogger { +class CdtParser( + config: Config, + headerFileFinder: HeaderFileFinder, + compilationDatabase: mutable.LinkedHashSet[CommandObject] +) extends ParseProblemsLogger + with PreprocessorStatementsLogger { - import io.joern.c2cpg.parser.CdtParser._ + import io.joern.c2cpg.parser.CdtParser.* - private val headerFileFinder = new HeaderFileFinder(config.inputPath) - private val parserConfig = ParserConfig.fromConfig(config) - private val definedSymbols = parserConfig.definedSymbols.asJava - private val includePaths = parserConfig.userIncludePaths - private val log = new DefaultLogService + private val parserConfig = ParserConfig.fromConfig(config, compilationDatabase) + private val definedSymbols = parserConfig.definedSymbols + private val includePaths = parserConfig.userIncludePaths + private val log = new DefaultLogService // enables parsing of code behind disabled preprocessor defines: private var opts: Int = ILanguage.OPTION_PARSE_INACTIVE_CODE @@ -55,9 +63,9 @@ class CdtParser(config: Config) extends ParseProblemsLogger with PreprocessorSta if (config.noImageLocations) opts |= ILanguage.OPTION_NO_IMAGE_LOCATIONS private def preprocessedFileIsFromCPPFile(file: Path, code: String): Boolean = { - if (config.withPreprocessedFiles && file.toString.endsWith(FileDefaults.PREPROCESSED_EXT)) { - val fileWithoutExt = file.toString.stripSuffix(FileDefaults.PREPROCESSED_EXT) - val filesWithCPPExt = FileDefaults.CPP_FILE_EXTENSIONS.map(ext => File(s"$fileWithoutExt$ext").name) + if (config.withPreprocessedFiles && FileDefaults.hasPreprocessedFileExtension(file.toString)) { + val fileWithoutExt = file.toString.substring(0, file.toString.lastIndexOf(".")) + val filesWithCPPExt = FileDefaults.CppFileExtensions.map(ext => File(s"$fileWithoutExt$ext").name) code.linesIterator.exists(line => filesWithCPPExt.exists(f => line.contains(s"\"$f\""))) } else { false @@ -65,7 +73,7 @@ class CdtParser(config: Config) extends ParseProblemsLogger with PreprocessorSta } private def createParseLanguage(file: Path, code: String): ILanguage = { - if (FileDefaults.isCPPFile(file.toString) || preprocessedFileIsFromCPPFile(file, code)) { + if (FileDefaults.hasCppFileExtension(file.toString) || preprocessedFileIsFromCPPFile(file, code)) { GPPLanguage.getDefault } else { GCCLanguage.getDefault @@ -74,9 +82,14 @@ class CdtParser(config: Config) extends ParseProblemsLogger with PreprocessorSta private def createScannerInfo(file: Path): ScannerInfo = { val additionalIncludes = - if (FileDefaults.isCPPFile(file.toString)) parserConfig.systemIncludePathsCPP + if (FileDefaults.hasCppFileExtension(file.toString)) parserConfig.systemIncludePathsCPP else parserConfig.systemIncludePathsC - new ScannerInfo(definedSymbols, (includePaths ++ additionalIncludes).map(_.toString).toArray) + val fileSpecificDefines = parserConfig.definedSymbolsPerFile.getOrElse(file.toString, Map.empty) + val fileSpecificIncludes = parserConfig.includesPerFile.getOrElse(file.toString, mutable.LinkedHashSet.empty) + new ScannerInfo( + (definedSymbols ++ fileSpecificDefines).asJava, + fileSpecificIncludes.toArray ++ (includePaths ++ additionalIncludes).map(_.toString).toArray + ) } private def parseInternal(code: String, inFile: File): IASTTranslationUnit = { @@ -91,14 +104,13 @@ class CdtParser(config: Config) extends ParseProblemsLogger with PreprocessorSta translationUnit } - private def parseInternal(file: Path): ParseResult = { - val realPath = File(file) - if (realPath.isRegularFile) { // handling potentially broken symlinks + private def parseInternal(file: File): ParseResult = { + if (file.isRegularFile) { // handling potentially broken symlinks try { - val fileContent = readFileAsFileContent(realPath.path) + val fileContent = readFileAsFileContent(file.path) val fileContentProvider = new CustomFileContentProvider(headerFileFinder) - val lang = createParseLanguage(realPath.path, fileContent.asInstanceOf[InternalFileContent].toString) - val scannerInfo = createScannerInfo(realPath.path) + val lang = createParseLanguage(file.path, fileContent.toString) + val scannerInfo = createScannerInfo(file.path) val translationUnit = lang.getASTTranslationUnit(fileContent, scannerInfo, fileContentProvider, null, opts, log) val problems = CPPVisitor.getProblems(translationUnit) if (parserConfig.logProblems) logProblems(problems.toList) @@ -119,7 +131,8 @@ class CdtParser(config: Config) extends ParseProblemsLogger with PreprocessorSta } else { ParseResult( None, - failure = Option(new NoSuchFileException(s"File '$realPath' does not exist. Check for broken symlinks!")) + failure = + Option(new NoSuchFileException(s"File '${file.pathAsString}' does not exist. Check for broken symlinks!")) ) } } diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/CustomFileContentProvider.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/CustomFileContentProvider.scala index b22a21bc4c8d..337cc3d18229 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/CustomFileContentProvider.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/CustomFileContentProvider.scala @@ -1,32 +1,40 @@ package io.joern.c2cpg.parser +import io.shiftleft.utils.IOUtils import org.eclipse.cdt.core.index.IIndexFileLocation +import org.eclipse.cdt.core.parser.FileContent import org.eclipse.cdt.internal.core.parser.IMacroDictionary -import org.eclipse.cdt.internal.core.parser.scanner.{InternalFileContent, InternalFileContentProvider} +import org.eclipse.cdt.internal.core.parser.scanner.InternalFileContent +import org.eclipse.cdt.internal.core.parser.scanner.InternalFileContentProvider import org.slf4j.LoggerFactory import java.nio.file.Paths +import java.util.concurrent.ConcurrentHashMap + +object CustomFileContentProvider { + private val headerContentCache: ConcurrentHashMap[String, Array[Char]] = new ConcurrentHashMap() +} class CustomFileContentProvider(headerFileFinder: HeaderFileFinder) extends InternalFileContentProvider { + import CustomFileContentProvider.headerContentCache + private val logger = LoggerFactory.getLogger(classOf[CustomFileContentProvider]) private def loadContent(path: String): InternalFileContent = { - val maybeFullPath = if (!getInclusionExists(path)) { - headerFileFinder.find(path) - } else { - Option(path) - } - maybeFullPath - .map { foundPath => - logger.debug(s"Loading header file '$foundPath'") - CdtParser.readFileAsFileContent(Paths.get(foundPath)).asInstanceOf[InternalFileContent] - } - .getOrElse { - logger.debug(s"Cannot find header file for '$path'") - null - } - + val maybeFullPath = if (!getInclusionExists(path)) { headerFileFinder.find(path) } + else { Option(path) } + maybeFullPath.map { foundPath => + val path = Paths.get(foundPath) + val content = headerContentCache.computeIfAbsent( + foundPath, + _ => { + logger.debug(s"Loading header file '$foundPath'") + IOUtils.readLinesInFile(path).mkString("\n").toArray + } + ) + FileContent.create(path.toString, false, content).asInstanceOf[InternalFileContent] + }.orNull } override def getContentForInclusion(path: String, macroDictionary: IMacroDictionary): InternalFileContent = diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/FileDefaults.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/FileDefaults.scala index 067b47412b25..b31cac460956 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/FileDefaults.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/FileDefaults.scala @@ -1,25 +1,37 @@ package io.joern.c2cpg.parser +import org.apache.commons.lang3.StringUtils + object FileDefaults { - val C_EXT: String = ".c" - val CPP_EXT: String = ".cpp" - val PREPROCESSED_EXT: String = ".i" + val CExt: String = ".c" + val CppExt: String = ".cpp" + val PreprocessedExt: String = ".i" + + private val CHeaderFileExtensions: Set[String] = + Set(".h") + + private val CppHeaderFileExtensions: Set[String] = + Set(".hpp", ".hh", ".hp", ".hxx", ".h++", ".tcc") + + val HeaderFileExtensions: Set[String] = + CHeaderFileExtensions ++ CppHeaderFileExtensions - private val CC_EXT = ".cc" - private val C_HEADER_EXT = ".h" - private val CPP_HEADER_EXT = ".hpp" - private val OTHER_HEADER_EXT = ".hh" + private val CppSourceFileExtensions: Set[String] = + Set(".cc", ".cxx", ".cpp", ".cp", ".ccm", ".cxxm", ".c++m") - val SOURCE_FILE_EXTENSIONS: Set[String] = Set(C_EXT, CC_EXT, CPP_EXT) + val CppFileExtensions: Set[String] = + CppSourceFileExtensions ++ CppHeaderFileExtensions - val HEADER_FILE_EXTENSIONS: Set[String] = Set(C_HEADER_EXT, CPP_HEADER_EXT, OTHER_HEADER_EXT) + val SourceFileExtensions: Set[String] = + CppSourceFileExtensions ++ Set(CExt) - val CPP_FILE_EXTENSIONS: Set[String] = Set(CC_EXT, CPP_EXT, CPP_HEADER_EXT) + def hasCppFileExtension(filePath: String): Boolean = + CppFileExtensions.exists(ext => StringUtils.endsWithIgnoreCase(filePath, ext)) - def isHeaderFile(filePath: String): Boolean = - HEADER_FILE_EXTENSIONS.exists(filePath.endsWith) + def hasSourceFileExtension(filePath: String): Boolean = + SourceFileExtensions.exists(ext => StringUtils.endsWithIgnoreCase(filePath, ext)) - def isCPPFile(filePath: String): Boolean = - CPP_FILE_EXTENSIONS.exists(filePath.endsWith) + def hasPreprocessedFileExtension(filePath: String): Boolean = + StringUtils.endsWithIgnoreCase(filePath, PreprocessedExt) } diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/HeaderFileFinder.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/HeaderFileFinder.scala index dbc6f36a1be9..f1a58d8cf68b 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/HeaderFileFinder.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/HeaderFileFinder.scala @@ -1,18 +1,24 @@ package io.joern.c2cpg.parser -import better.files._ +import better.files.* +import io.joern.c2cpg.C2Cpg.DefaultIgnoredFolders +import io.joern.c2cpg.Config import io.joern.x2cpg.SourceFiles import org.jline.utils.Levenshtein -import java.nio.file.Path +class HeaderFileFinder(config: Config) { -class HeaderFileFinder(root: String) { - - private val nameToPathMap: Map[String, List[Path]] = SourceFiles - .determine(root, FileDefaults.HEADER_FILE_EXTENSIONS) + private val nameToPathMap: Map[String, List[String]] = SourceFiles + .determine( + config.inputPath, + FileDefaults.HeaderFileExtensions, + ignoredDefaultRegex = Option(DefaultIgnoredFolders), + ignoredFilesRegex = Option(config.ignoredFilesRegex), + ignoredFilesPath = Option(config.ignoredFiles) + ) .map { p => val file = File(p) - (file.name, file.path) + (file.name, file.pathAsString) } .groupBy(_._1) .map(x => (x._1, x._2.map(_._2))) @@ -22,7 +28,7 @@ class HeaderFileFinder(root: String) { */ def find(path: String): Option[String] = File(path).nameOption.flatMap { name => val matches = nameToPathMap.getOrElse(name, List()) - matches.map(_.toString).sortBy(x => Levenshtein.distance(x, path)).headOption + matches.sortBy(x => Levenshtein.distance(x, path)).headOption } } diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/JSONCompilationDatabaseParser.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/JSONCompilationDatabaseParser.scala new file mode 100644 index 000000000000..a1e5f750f5a3 --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/JSONCompilationDatabaseParser.scala @@ -0,0 +1,106 @@ +package io.joern.c2cpg.parser + +import io.joern.x2cpg.SourceFiles +import io.shiftleft.utils.IOUtils +import org.slf4j.LoggerFactory +import ujson.Value + +import java.nio.file.Paths +import scala.collection.mutable +import scala.util.Try + +object JSONCompilationDatabaseParser { + + private val logger = LoggerFactory.getLogger(getClass) + + /** {{{ + * 1) -D: Matches the -D flag, which is the key prefix for defining macros. + * 2) ([A-Za-z_][A-Za-z0-9_]+): Matches a valid macro name (which must start with a letter or underscore and can be followed by letters, numbers, or underscores). + * 3) (=(\\*".*"))?: Optionally matches = followed by either: + * a) A quoted string: Allows for strings in quotes. + * b) Any char sequence (.*") closed with a quote. + * }}} + */ + private val defineInCommandPattern = """-D([A-Za-z_][A-Za-z0-9_]+)(=(\\*".*"))?""".r + + /** {{{ + * 1) -I: Matches the -I flag, which indicates an include directory. + * 2) (\S+): Matches one or more non-whitespace characters, which represent the path of the directory. + * }}} + */ + private val includeInCommandPattern = """-I(\S+)""".r + + case class CommandObject( + directory: String, + arguments: mutable.LinkedHashSet[String], + command: mutable.LinkedHashSet[String], + file: String + ) { + + /** @return + * the file path (guaranteed to be absolute) + */ + def compiledFile(): String = SourceFiles.toAbsolutePath(file, directory) + + private def nameValuePairFromDefine(define: String): (String, String) = { + val s = define.stripPrefix("-D") + if (s.contains("=")) { + val split = s.split("=") + (split.head, split(1)) + } else { + (s, "") + } + } + + private def pathFromInclude(include: String): String = include.stripPrefix("-I") + + def includes(): mutable.LinkedHashSet[String] = { + val includesFromArguments = arguments.filter(a => a.startsWith("-I")).map(pathFromInclude) + val includesFromCommand = command.flatMap { c => + val includes = includeInCommandPattern.findAllIn(c).toList + includes.map(pathFromInclude) + } + includesFromArguments ++ includesFromCommand + } + + def defines(): mutable.LinkedHashSet[(String, String)] = { + val definesFromArguments = arguments.filter(a => a.startsWith("-D")).map(nameValuePairFromDefine) + val definesFromCommand = command.flatMap { c => + val defines = defineInCommandPattern.findAllIn(c).toList + defines.map(nameValuePairFromDefine) + } + definesFromArguments ++ definesFromCommand + } + } + + private def hasKey(node: Value, key: String): Boolean = Try(node(key)).isSuccess + + private def safeArguments(obj: Value): mutable.LinkedHashSet[String] = { + if (hasKey(obj, "arguments")) + obj("arguments").arrOpt + .map(arr => mutable.LinkedHashSet.from(arr).map(_.str)) + .getOrElse(mutable.LinkedHashSet.empty) + else mutable.LinkedHashSet.empty + } + + private def safeCommand(obj: Value): mutable.LinkedHashSet[String] = { + if (hasKey(obj, "command")) mutable.LinkedHashSet.empty.addOne(obj("command").str) + else mutable.LinkedHashSet.empty + } + + def parse(compileCommandsJson: String): mutable.LinkedHashSet[CommandObject] = { + try { + val jsonContent = IOUtils.readEntireFile(Paths.get(compileCommandsJson)) + val json = ujson.read(jsonContent) + val allCommandObjects = mutable.LinkedHashSet.from(json.arr) + allCommandObjects.map { obj => + CommandObject(obj("directory").str, safeArguments(obj), safeCommand(obj), obj("file").str) + } + } catch { + case t: Throwable => + logger.warn(s"Could not parse '$compileCommandsJson'", t) + mutable.LinkedHashSet.empty + } + } + +} diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/ParserConfig.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/ParserConfig.scala index 7bb82a6b751e..93bce7b9df21 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/ParserConfig.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/parser/ParserConfig.scala @@ -1,36 +1,61 @@ package io.joern.c2cpg.parser import io.joern.c2cpg.Config +import io.joern.c2cpg.parser.JSONCompilationDatabaseParser.CommandObject import io.joern.c2cpg.utils.IncludeAutoDiscovery import java.nio.file.{Path, Paths} +import scala.collection.mutable object ParserConfig { def empty: ParserConfig = - ParserConfig(Set.empty, Set.empty, Set.empty, Map.empty, logProblems = false, logPreprocessor = false) + ParserConfig( + mutable.LinkedHashSet.empty, + mutable.LinkedHashSet.empty, + mutable.LinkedHashSet.empty, + Map.empty, + Map.empty, + Map.empty, + logProblems = false, + logPreprocessor = false + ) - def fromConfig(config: Config): ParserConfig = ParserConfig( - config.includePaths.map(Paths.get(_).toAbsolutePath), - IncludeAutoDiscovery.discoverIncludePathsC(config), - IncludeAutoDiscovery.discoverIncludePathsCPP(config), - config.defines.map { - case define if define.contains("=") => - val s = define.split("=") - s.head -> s(1) - case define => define -> "true" - }.toMap ++ DefaultDefines.DEFAULT_CALL_CONVENTIONS, - config.logProblems, - config.logPreprocessor - ) + def fromConfig(config: Config, compilationDatabase: mutable.LinkedHashSet[CommandObject]): ParserConfig = { + val compilationDatabaseDefines = compilationDatabase.map { c => + c.compiledFile() -> c.defines().toMap + }.toMap + val includes = compilationDatabase.map { c => + c.compiledFile() -> c.includes() + }.toMap + ParserConfig( + mutable.LinkedHashSet.from(config.includePaths.map(Paths.get(_).toAbsolutePath)), + IncludeAutoDiscovery.discoverIncludePathsC(config), + IncludeAutoDiscovery.discoverIncludePathsCPP(config), + config.defines.map { define => + if (define.contains("=")) { + val split = define.split("=") + split.head -> split(1) + } else { + define -> "" + } + }.toMap ++ DefaultDefines.DEFAULT_CALL_CONVENTIONS, + compilationDatabaseDefines, + includes, + config.logProblems, + config.logPreprocessor + ) + } } case class ParserConfig( - userIncludePaths: Set[Path], - systemIncludePathsC: Set[Path], - systemIncludePathsCPP: Set[Path], + userIncludePaths: mutable.LinkedHashSet[Path], + systemIncludePathsC: mutable.LinkedHashSet[Path], + systemIncludePathsCPP: mutable.LinkedHashSet[Path], definedSymbols: Map[String, String], + definedSymbolsPerFile: Map[String, Map[String, String]], + includesPerFile: Map[String, mutable.LinkedHashSet[String]], logProblems: Boolean, logPreprocessor: Boolean ) diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/AstCreationPass.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/AstCreationPass.scala index a140db208fe4..73ab6d778ad2 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/AstCreationPass.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/AstCreationPass.scala @@ -3,34 +3,54 @@ package io.joern.c2cpg.passes import io.joern.c2cpg.C2Cpg.DefaultIgnoredFolders import io.joern.c2cpg.Config import io.joern.c2cpg.astcreation.AstCreator -import io.joern.c2cpg.astcreation.Defines -import io.joern.c2cpg.parser.{CdtParser, FileDefaults} -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.passes.ForkJoinParallelCpgPass +import io.joern.c2cpg.astcreation.CGlobal +import io.joern.c2cpg.parser.CdtParser +import io.joern.c2cpg.parser.FileDefaults +import io.joern.c2cpg.parser.HeaderFileFinder +import io.joern.c2cpg.parser.JSONCompilationDatabaseParser +import io.joern.c2cpg.parser.JSONCompilationDatabaseParser.CommandObject import io.joern.x2cpg.SourceFiles -import io.joern.x2cpg.datastructures.Global import io.joern.x2cpg.utils.Report import io.joern.x2cpg.utils.TimeUtils +import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.passes.ForkJoinParallelCpgPass +import org.slf4j.Logger +import org.slf4j.LoggerFactory import java.nio.file.Paths import java.util.concurrent.ConcurrentHashMap -import scala.util.matching.Regex +import scala.collection.mutable import scala.jdk.CollectionConverters.* +import scala.util.Failure +import scala.util.Success +import scala.util.Try +import scala.util.matching.Regex class AstCreationPass(cpg: Cpg, config: Config, report: Report = new Report()) extends ForkJoinParallelCpgPass[String](cpg) { + private val logger: Logger = LoggerFactory.getLogger(classOf[AstCreationPass]) + + private val global = new CGlobal() + private val headerFileFinder = new HeaderFileFinder(config) private val file2OffsetTable: ConcurrentHashMap[String, Array[Int]] = new ConcurrentHashMap() - private val parser: CdtParser = new CdtParser(config) - private val global = new Global() + private val compilationDatabase: mutable.LinkedHashSet[CommandObject] = + config.compilationDatabase.map(JSONCompilationDatabaseParser.parse).getOrElse(mutable.LinkedHashSet.empty) - def typesSeen(): List[String] = global.usedTypes.keys().asScala.filterNot(_ == Defines.anyTypeName).toList + private val parser: CdtParser = new CdtParser(config, headerFileFinder, compilationDatabase) - override def generateParts(): Array[String] = { - val sourceFileExtensions = FileDefaults.SOURCE_FILE_EXTENSIONS - ++ FileDefaults.HEADER_FILE_EXTENSIONS - ++ Option.when(config.withPreprocessedFiles)(FileDefaults.PREPROCESSED_EXT).toList + def typesSeen(): List[String] = global.usedTypes.keys().asScala.toList + + def unhandledMethodDeclarations(): Map[String, CGlobal.MethodInfo] = { + global.methodDeclarations.asScala.toMap -- global.methodDefinitions.asScala.keys + } + + private def sourceFilesFromDirectory(): Array[String] = { + val sourceFileExtensions = + FileDefaults.SourceFileExtensions ++ + FileDefaults.HeaderFileExtensions ++ + Option.when(config.withPreprocessedFiles)(FileDefaults.PreprocessedExt).toList val allSourceFiles = SourceFiles .determine( config.inputPath, @@ -42,8 +62,8 @@ class AstCreationPass(cpg: Cpg, config: Config, report: Report = new Report()) .toArray if (config.withPreprocessedFiles) { allSourceFiles.filter { - case f if !f.endsWith(FileDefaults.PREPROCESSED_EXT) => - val fAsPreprocessedFile = s"${f.substring(0, f.lastIndexOf("."))}${FileDefaults.PREPROCESSED_EXT}" + case f if !FileDefaults.hasPreprocessedFileExtension(f) => + val fAsPreprocessedFile = s"${f.substring(0, f.lastIndexOf("."))}${FileDefaults.PreprocessedExt}" !allSourceFiles.exists { sourceFile => f != sourceFile && sourceFile == fAsPreprocessedFile } case _ => true } @@ -52,6 +72,29 @@ class AstCreationPass(cpg: Cpg, config: Config, report: Report = new Report()) } } + private def sourceFilesFromCompilationDatabase(compilationDatabaseFile: String): Array[String] = { + if (compilationDatabase.isEmpty) { + logger.warn(s"'$compilationDatabaseFile' contains no source files. CPG will be empty.") + } + SourceFiles + .filterFiles( + compilationDatabase.map(_.compiledFile()).toList, + config.inputPath, + ignoredDefaultRegex = Option(DefaultIgnoredFolders), + ignoredFilesRegex = Option(config.ignoredFilesRegex), + ignoredFilesPath = Option(config.ignoredFiles) + ) + .toArray + } + + override def generateParts(): Array[String] = { + if (config.compilationDatabase.isEmpty) { + sourceFilesFromDirectory() + } else { + sourceFilesFromCompilationDatabase(config.compilationDatabase.get) + } + } + override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = { val path = Paths.get(filename).toAbsolutePath val relPath = SourceFiles.toRelativePath(path.toString, config.inputPath) @@ -61,11 +104,18 @@ class AstCreationPass(cpg: Cpg, config: Config, report: Report = new Report()) parseResult match { case Some(translationUnit) => report.addReportInfo(relPath, fileLOC, parsed = true) - val localDiff = new AstCreator(relPath, global, config, translationUnit, file2OffsetTable)( - config.schemaValidation - ).createAst() - diffGraph.absorb(localDiff) - true + Try { + val localDiff = + new AstCreator(relPath, global, config, translationUnit, headerFileFinder, file2OffsetTable)( + config.schemaValidation + ).createAst() + diffGraph.absorb(localDiff) + } match { + case Failure(exception) => + logger.warn(s"Failed to generate a CPG for: '$filename'", exception) + false + case Success(_) => true + } case None => report.addReportInfo(relPath, fileLOC) false diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/FunctionDeclNodePass.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/FunctionDeclNodePass.scala new file mode 100644 index 000000000000..ceba5ba84df0 --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/FunctionDeclNodePass.scala @@ -0,0 +1,174 @@ +package io.joern.c2cpg.passes + +import io.joern.c2cpg.astcreation.CGlobal +import io.joern.x2cpg.Ast +import io.joern.x2cpg.Defines +import io.joern.x2cpg.ValidationMode +import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.codepropertygraph.generated.nodes.NewBlock +import io.shiftleft.codepropertygraph.generated.nodes.NewMethod +import io.shiftleft.codepropertygraph.generated.nodes.NewMethodParameterIn +import io.shiftleft.codepropertygraph.generated.nodes.NewMethodReturn +import io.shiftleft.codepropertygraph.generated.EvaluationStrategies +import io.shiftleft.codepropertygraph.generated.nodes.NewBinding +import io.shiftleft.codepropertygraph.generated.nodes.NewTypeDecl +import io.shiftleft.codepropertygraph.generated.EdgeTypes +import io.shiftleft.codepropertygraph.generated.NodeTypes +import io.shiftleft.codepropertygraph.generated.nodes.NewModifier +import io.shiftleft.passes.CpgPass +import io.shiftleft.semanticcpg.language.* +import org.apache.commons.lang3.StringUtils + +import scala.collection.immutable.Map + +class FunctionDeclNodePass(cpg: Cpg, methodDeclarations: Map[String, CGlobal.MethodInfo])(implicit + withSchemaValidation: ValidationMode +) extends CpgPass(cpg) { + + private def methodNode(fullName: String, methodNodeInfo: CGlobal.MethodInfo): NewMethod = { + val node_ = + NewMethod() + .name(StringUtils.normalizeSpace(methodNodeInfo.name)) + .code(methodNodeInfo.code) + .fullName(StringUtils.normalizeSpace(fullName)) + .filename(methodNodeInfo.fileName) + .astParentType(methodNodeInfo.astParentType) + .astParentFullName(methodNodeInfo.astParentFullName) + .isExternal(false) + .lineNumber(methodNodeInfo.lineNumber) + .columnNumber(methodNodeInfo.columnNumber) + .lineNumberEnd(methodNodeInfo.lineNumberEnd) + .columnNumberEnd(methodNodeInfo.columnNumberEnd) + .signature(StringUtils.normalizeSpace(methodNodeInfo.signature)) + methodNodeInfo.offset.foreach { case (offset, offsetEnd) => + node_.offset(offset).offsetEnd(offsetEnd) + } + node_ + } + + private def parameterInNode(parameterNodeInfo: CGlobal.ParameterInfo): NewMethodParameterIn = { + NewMethodParameterIn() + .name(parameterNodeInfo.name) + .code(parameterNodeInfo.code) + .index(parameterNodeInfo.index) + .order(parameterNodeInfo.index) + .isVariadic(parameterNodeInfo.isVariadic) + .evaluationStrategy(parameterNodeInfo.evaluationStrategy) + .lineNumber(parameterNodeInfo.lineNumber) + .columnNumber(parameterNodeInfo.columnNumber) + .typeFullName(parameterNodeInfo.typeFullName) + } + + private def methodReturnNode(typeFullName: String, line: Option[Int], column: Option[Int]): NewMethodReturn = + NewMethodReturn() + .typeFullName(typeFullName) + .code("RET") + .evaluationStrategy(EvaluationStrategies.BY_VALUE) + .lineNumber(line) + .columnNumber(column) + + private def typeDeclNode( + name: String, + fullName: String, + filename: String, + code: String, + astParentType: String, + astParentFullName: String, + line: Option[Int], + column: Option[Int], + offset: Option[(Int, Int)] + ): NewTypeDecl = { + val node_ = NewTypeDecl() + .name(name) + .fullName(fullName) + .code(code) + .isExternal(false) + .filename(filename) + .astParentType(astParentType) + .astParentFullName(astParentFullName) + .lineNumber(line) + .columnNumber(column) + offset.foreach { case (offset, offsetEnd) => + node_.offset(offset).offsetEnd(offsetEnd) + } + node_ + } + + private def methodStubAst( + method: NewMethod, + parameters: Seq[Ast], + methodReturn: NewMethodReturn, + modifier: Seq[Ast] + ): Ast = + Ast(method) + .withChildren(parameters) + .withChild(Ast(NewBlock().typeFullName(Defines.Any))) + .withChildren(modifier) + .withChild(Ast(methodReturn)) + + private def createFunctionTypeAndTypeDecl( + methodInfo: CGlobal.MethodInfo, + method: NewMethod, + methodName: String, + methodFullName: String, + signature: String, + dstGraph: DiffGraphBuilder + ): Ast = { + val normalizedName = StringUtils.normalizeSpace(methodName) + val normalizedFullName = StringUtils.normalizeSpace(methodFullName) + + if (methodInfo.astParentType == NodeTypes.TYPE_DECL) { + val parentTypeDecl = cpg.typeDecl.nameExact(methodInfo.astParentFullName).headOption + parentTypeDecl + .map { typeDecl => + val functionBinding = + NewBinding().name(normalizedName).methodFullName(normalizedFullName).signature(signature) + dstGraph.addEdge(typeDecl, functionBinding, EdgeTypes.BINDS) + Ast(functionBinding).withRefEdge(functionBinding, method) + } + .getOrElse(Ast()) + } else { + val typeDecl = typeDeclNode( + normalizedName, + normalizedFullName, + method.filename, + normalizedName, + methodInfo.astParentType, + methodInfo.astParentFullName, + methodInfo.lineNumber, + methodInfo.columnNumber, + methodInfo.offset + ) + Ast.storeInDiffGraph(Ast(typeDecl), dstGraph) + method.astParentFullName = typeDecl.fullName + method.astParentType = typeDecl.label + val functionBinding = NewBinding().name(normalizedName).methodFullName(normalizedFullName).signature(signature) + Ast(functionBinding).withBindsEdge(typeDecl, functionBinding).withRefEdge(functionBinding, method) + } + } + + override def run(dstGraph: DiffGraphBuilder): Unit = { + methodDeclarations.foreach { case (fullName, methodNodeInfo) => + val methodNode_ = methodNode(fullName, methodNodeInfo) + val parameterNodes = methodNodeInfo.parameter.map(p => Ast(parameterInNode(p))) + val stubAst = + methodStubAst( + methodNode_, + parameterNodes, + methodReturnNode(methodNodeInfo.returnType, methodNodeInfo.lineNumber, methodNodeInfo.columnNumber), + methodNodeInfo.modifier.map(m => Ast(NewModifier().modifierType(m))) + ) + val typeDeclAst = createFunctionTypeAndTypeDecl( + methodNodeInfo, + methodNode_, + methodNodeInfo.name, + fullName, + methodNodeInfo.signature, + dstGraph + ) + val ast = stubAst.merge(typeDeclAst) + Ast.storeInDiffGraph(ast, dstGraph) + } + } + +} diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/PreprocessorPass.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/PreprocessorPass.scala index 3a884d7a9257..f60964bd4ba1 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/PreprocessorPass.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/PreprocessorPass.scala @@ -3,32 +3,69 @@ package io.joern.c2cpg.passes import io.joern.c2cpg.C2Cpg.DefaultIgnoredFolders import io.joern.c2cpg.Config import io.joern.c2cpg.parser.{CdtParser, FileDefaults} +import io.joern.c2cpg.parser.HeaderFileFinder +import io.joern.c2cpg.parser.JSONCompilationDatabaseParser +import io.joern.c2cpg.parser.JSONCompilationDatabaseParser.CommandObject import io.joern.x2cpg.SourceFiles import org.eclipse.cdt.core.dom.ast.{ - IASTPreprocessorIfStatement, IASTPreprocessorIfdefStatement, + IASTPreprocessorIfStatement, IASTPreprocessorStatement } +import org.slf4j.LoggerFactory import java.nio.file.Paths +import scala.collection.mutable import scala.collection.parallel.CollectionConverters.ImmutableIterableIsParallelizable import scala.collection.parallel.immutable.ParIterable class PreprocessorPass(config: Config) { - private val parser = new CdtParser(config) + private val logger = LoggerFactory.getLogger(classOf[PreprocessorPass]) + + private val compilationDatabase: mutable.LinkedHashSet[CommandObject] = + config.compilationDatabase.map(JSONCompilationDatabaseParser.parse).getOrElse(mutable.LinkedHashSet.empty) - def run(): ParIterable[String] = + private val headerFileFinder = new HeaderFileFinder(config) + private val parser = new CdtParser(config, headerFileFinder, compilationDatabase) + + private def sourceFilesFromDirectory(): ParIterable[String] = { SourceFiles .determine( config.inputPath, - FileDefaults.SOURCE_FILE_EXTENSIONS, + FileDefaults.SourceFileExtensions, + ignoredDefaultRegex = Option(DefaultIgnoredFolders), + ignoredFilesRegex = Option(config.ignoredFilesRegex), + ignoredFilesPath = Option(config.ignoredFiles) + ) + .par + .flatMap(runOnPart) + } + + private def sourceFilesFromCompilationDatabase(compilationDatabaseFile: String): ParIterable[String] = { + if (compilationDatabase.isEmpty) { + logger.warn(s"'$compilationDatabaseFile' contains no source files.") + } + SourceFiles + .filterFiles( + compilationDatabase.map(_.compiledFile()).toList, + config.inputPath, ignoredDefaultRegex = Option(DefaultIgnoredFolders), ignoredFilesRegex = Option(config.ignoredFilesRegex), ignoredFilesPath = Option(config.ignoredFiles) ) .par .flatMap(runOnPart) + } + + def run(): ParIterable[String] = { + if (config.compilationDatabase.isEmpty) { + sourceFilesFromDirectory() + } else { + sourceFilesFromCompilationDatabase(config.compilationDatabase.get) + } + + } private def preprocessorStatement2String(stmt: IASTPreprocessorStatement): Option[String] = stmt match { case s: IASTPreprocessorIfStatement => diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/TypeDeclNodePass.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/TypeDeclNodePass.scala index f8437a16004b..f4dc1ead9293 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/TypeDeclNodePass.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/passes/TypeDeclNodePass.scala @@ -34,8 +34,8 @@ class TypeDeclNodePass(cpg: Cpg)(implicit withSchemaValidation: ValidationMode) .lineNumber(1) .astParentType(NodeTypes.NAMESPACE_BLOCK) .astParentFullName(fullName) - val blockNode = NewBlock().typeFullName(Defines.anyTypeName) - val methodReturn = newMethodReturnNode(Defines.anyTypeName, line = None, column = None) + val blockNode = NewBlock().typeFullName(Defines.Any) + val methodReturn = newMethodReturnNode(Defines.Any, line = None, column = None) Ast(includesFile).withChild( Ast(namespaceBlock) .withChild(Ast(fakeGlobalIncludesMethod).withChild(Ast(blockNode)).withChild(Ast(methodReturn))) diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/utils/ExternalCommand.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/utils/ExternalCommand.scala deleted file mode 100644 index c213b6e8fe71..000000000000 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/utils/ExternalCommand.scala +++ /dev/null @@ -1,33 +0,0 @@ -package io.joern.c2cpg.utils - -import java.util.concurrent.ConcurrentLinkedQueue -import scala.sys.process.{Process, ProcessLogger} -import scala.util.{Failure, Success, Try} -import scala.jdk.CollectionConverters.* - -object ExternalCommand extends io.joern.x2cpg.utils.ExternalCommand { - - override def handleRunResult(result: Try[Int], stdOut: Seq[String], stdErr: Seq[String]): Try[Seq[String]] = { - result match { - case Success(0) => - Success(stdOut) - case Success(1) if IsWin && IncludeAutoDiscovery.gccAvailable() => - // the command to query the system header file locations within a Windows - // environment always returns Success(1) for whatever reason... - Success(stdOut) - case _ => - Failure(new RuntimeException(stdOut.mkString(System.lineSeparator()))) - } - } - - override def run(command: String, cwd: String, extraEnv: Map[String, String] = Map.empty): Try[Seq[String]] = { - val stdOutOutput = new ConcurrentLinkedQueue[String] - val processLogger = ProcessLogger(stdOutOutput.add, stdOutOutput.add) - val process = shellPrefix match { - case Nil => Process(command, new java.io.File(cwd), extraEnv.toList*) - case _ => Process(shellPrefix :+ command, new java.io.File(cwd), extraEnv.toList*) - } - handleRunResult(Try(process.!(processLogger)), stdOutOutput.asScala.toSeq, Nil) - } - -} diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/utils/GccSpecificExternalCommand.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/utils/GccSpecificExternalCommand.scala new file mode 100644 index 000000000000..761a39549108 --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/utils/GccSpecificExternalCommand.scala @@ -0,0 +1,26 @@ +package io.joern.c2cpg.utils + +import io.joern.x2cpg.utils.ExternalCommand + +import scala.util.{Failure, Success, Try} + +object GccSpecificExternalCommand { + + import ExternalCommand.ExternalCommandResult + + private val IsWin = scala.util.Properties.isWin + + def run(command: Seq[String], cwd: String, extraEnv: Map[String, String] = Map.empty): Try[Seq[String]] = { + ExternalCommand.run(command, cwd, mergeStdErrInStdOut = true, extraEnv) match { + case ExternalCommandResult(0, stdOut, _) => + Success(stdOut) + case ExternalCommandResult(1, stdOut, _) if IsWin && IncludeAutoDiscovery.gccAvailable() => + // the command to query the system header file locations within a Windows + // environment always returns Success(1) for whatever reason... + Success(stdOut) + case ExternalCommandResult(_, stdOut, _) => + Failure(new RuntimeException(stdOut.mkString(System.lineSeparator()))) + } + } + +} diff --git a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/utils/IncludeAutoDiscovery.scala b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/utils/IncludeAutoDiscovery.scala index 131434e29564..27aaba45910a 100644 --- a/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/utils/IncludeAutoDiscovery.scala +++ b/joern-cli/frontends/c2cpg/src/main/scala/io/joern/c2cpg/utils/IncludeAutoDiscovery.scala @@ -1,9 +1,13 @@ package io.joern.c2cpg.utils +import better.files.File import io.joern.c2cpg.Config +import io.joern.x2cpg.utils.ExternalCommand import org.slf4j.LoggerFactory -import java.nio.file.{Path, Paths} +import java.nio.file.Path +import java.nio.file.Paths +import scala.collection.mutable import scala.util.Failure import scala.util.Success @@ -11,26 +15,35 @@ object IncludeAutoDiscovery { private val logger = LoggerFactory.getLogger(IncludeAutoDiscovery.getClass) - private val IS_WIN = scala.util.Properties.isWin + private val IsWin = scala.util.Properties.isWin - val GCC_VERSION_COMMAND = "gcc --version" + private val GccVersionCommand = Seq("gcc", "--version") - private val CPP_INCLUDE_COMMAND = - if (IS_WIN) "gcc -xc++ -E -v . -o nul" else "gcc -xc++ -E -v /dev/null -o /dev/null" + private val CppIncludeCommand = + if (IsWin) Seq("gcc", "-xc++", "-E", "-v", ".", "-o", "nul") + else Seq("gcc", "-xc++", "-E", "-v", "/dev/null", "-o", "/dev/null") - private val C_INCLUDE_COMMAND = - if (IS_WIN) "gcc -xc -E -v . -o nul" else "gcc -xc -E -v /dev/null -o /dev/null" + private val CIncludeCommand = + if (IsWin) Seq("gcc", "-xc", "-E", "-v", ".", "-o", "nul") + else Seq("gcc", "-xc", "-E", "-v", "/dev/null", "-o", "/dev/null") + + private val VsWhereCommand = Seq( + "cmd.exe", + "/C", + "\"%ProgramFiles(x86)%\\Microsoft Visual Studio\\Installer\\vswhere.exe\" -property installationPath" + ) + + private val VcVarsCommand = Seq("cmd.exe", "/C", "VC\\Auxiliary\\Build\\vcvars64.bat") // Only check once private var isGccAvailable: Option[Boolean] = None // Only discover them once - private var systemIncludePathsC: Set[Path] = Set.empty - private var systemIncludePathsCPP: Set[Path] = Set.empty + private var systemIncludePathsC: mutable.LinkedHashSet[Path] = mutable.LinkedHashSet.empty + private var systemIncludePathsCPP: mutable.LinkedHashSet[Path] = mutable.LinkedHashSet.empty private def checkForGcc(): Boolean = { - logger.debug("Checking gcc ...") - ExternalCommand.run(GCC_VERSION_COMMAND, ".") match { + ExternalCommand.run(GccVersionCommand, ".").toTry match { case Success(result) => logger.debug(s"GCC is available: ${result.mkString(System.lineSeparator())}") true @@ -48,52 +61,96 @@ object IncludeAutoDiscovery { isGccAvailable.get } - private def extractPaths(output: Seq[String]): Set[Path] = { + private def extractPaths(output: Seq[String]): mutable.LinkedHashSet[Path] = { val startIndex = output.indexWhere(_.contains("#include")) + 2 val endIndex = - if (IS_WIN) output.indexWhere(_.startsWith("End of search list.")) - 1 + if (IsWin) output.indexWhere(_.startsWith("End of search list.")) - 1 else output.indexWhere(_.startsWith("COMPILER_PATH")) - 1 - output.slice(startIndex, endIndex).map(p => Paths.get(p.trim).toRealPath()).toSet + mutable.LinkedHashSet.from(output.slice(startIndex, endIndex).map(p => Paths.get(p.trim).toRealPath())) + } + + private def discoverPaths(command: Seq[String]): mutable.LinkedHashSet[Path] = + GccSpecificExternalCommand.run(command, ".") match { + case Success(output) => extractPaths(output) + case Failure(exception) => + logger.warn(s"Unable to discover system include paths. Running '$command' failed.", exception) + mutable.LinkedHashSet.empty + } + + private def discoverMSVCInstallPath(): Option[String] = { + GccSpecificExternalCommand.run(VsWhereCommand, ".") match { + case Success(output) => + output.headOption + case Failure(exception) => + logger.warn(s"Unable to discover MSVC installation path.", exception) + None + } + } + + private def extractMSVCIncludePaths(resolvedInstallationPath: String): mutable.LinkedHashSet[Path] = { + GccSpecificExternalCommand.run(VcVarsCommand, resolvedInstallationPath, Map("VSCMD_DEBUG" -> "3")) match { + case Success(results) => + results.find(_.startsWith("INCLUDE=")) match { + case Some(includesLine) => + val includesString = includesLine.replaceFirst("INCLUDE=", "") + mutable.LinkedHashSet.from(includesString.split(";").map(p => Paths.get(p.trim).toRealPath())) + case None => mutable.LinkedHashSet.empty + } + case Failure(exception) => + logger.warn(s"Unable to discover MSVC system include paths.", exception) + mutable.LinkedHashSet.empty + } } - private def discoverPaths(command: String): Set[Path] = ExternalCommand.run(command, ".") match { - case Success(output) => extractPaths(output) - case Failure(exception) => - logger.warn(s"Unable to discover system include paths. Running '$command' failed.", exception) - Set.empty + private def discoverMSVCPaths(): mutable.LinkedHashSet[Path] = { + discoverMSVCInstallPath().map(extractMSVCIncludePaths).getOrElse(mutable.LinkedHashSet.empty) } - def discoverIncludePathsC(config: Config): Set[Path] = { - if (config.includePathsAutoDiscovery && systemIncludePathsC.nonEmpty) { - systemIncludePathsC - } else if (config.includePathsAutoDiscovery && systemIncludePathsC.isEmpty && gccAvailable()) { - val includePathsC = discoverPaths(C_INCLUDE_COMMAND) - if (includePathsC.nonEmpty) { - logger.info(s"Using the following C system include paths:${includePathsC - .mkString(s"${System.lineSeparator()}- ", s"${System.lineSeparator()}- ", System.lineSeparator())}") - } - systemIncludePathsC = includePathsC - includePathsC - } else { - Set.empty + private def reportIncludePaths(paths: mutable.LinkedHashSet[Path], lang: String): Unit = { + if (paths.nonEmpty) { + val ls = System.lineSeparator() + logger.info(s"Using the following $lang system include paths:${paths.mkString(s"$ls- ", s"$ls- ", ls)}") } } - def discoverIncludePathsCPP(config: Config): Set[Path] = { - if (config.includePathsAutoDiscovery && systemIncludePathsCPP.nonEmpty) { - systemIncludePathsCPP - } else if (config.includePathsAutoDiscovery && systemIncludePathsCPP.isEmpty && gccAvailable()) { - val includePathsCPP = discoverPaths(CPP_INCLUDE_COMMAND) - if (includePathsCPP.nonEmpty) { - logger.info(s"Using the following CPP system include paths:${includePathsCPP - .mkString(s"${System.lineSeparator()}- ", s"${System.lineSeparator()}- ", System.lineSeparator())}") - } - systemIncludePathsCPP = includePathsCPP - includePathsCPP - } else { - Set.empty + def discoverIncludePathsC(config: Config): mutable.LinkedHashSet[Path] = { + if (!config.includePathsAutoDiscovery) return mutable.LinkedHashSet.empty + if (systemIncludePathsC.nonEmpty) return systemIncludePathsC + + if (isMSVCProject(config)) { + systemIncludePathsCPP = discoverMSVCPaths() // discovers paths for both languages + systemIncludePathsC = systemIncludePathsCPP + reportIncludePaths(systemIncludePathsC, "MSVC") + } + if (systemIncludePathsC.isEmpty && gccAvailable()) { + systemIncludePathsC = discoverPaths(CIncludeCommand) + reportIncludePaths(systemIncludePathsC, "C") + } + systemIncludePathsC + } + + private def isMSVCProject(config: Config): Boolean = { + if (!IsWin) return false + val projectDir = File(config.inputPath) + List(projectDir / ".vs", projectDir / ".vscode").exists(_.exists) || + projectDir.list.exists(_.`extension`(includeDot = false).exists(ext => ext == "sln" || ext == "vcxproj")) + } + + def discoverIncludePathsCPP(config: Config): mutable.LinkedHashSet[Path] = { + if (!config.includePathsAutoDiscovery) return mutable.LinkedHashSet.empty + if (systemIncludePathsCPP.nonEmpty) return systemIncludePathsCPP + + if (isMSVCProject(config)) { + systemIncludePathsCPP = discoverMSVCPaths() // discovers paths for both languages + systemIncludePathsC = systemIncludePathsCPP + reportIncludePaths(systemIncludePathsCPP, "MSVC") + } + if (systemIncludePathsCPP.isEmpty && gccAvailable()) { + systemIncludePathsCPP = discoverPaths(CppIncludeCommand) + reportIncludePaths(systemIncludePathsCPP, "CPP") } + systemIncludePathsCPP } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/config/ConfigTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/config/ConfigTests.scala index 2a73c8c8c317..7e2d3c665401 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/config/ConfigTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/config/ConfigTests.scala @@ -18,7 +18,9 @@ class ConfigTests extends AnyWordSpec with Matchers with Inside { "--output", "OUTPUT", "--exclude", - "1EXCLUDE_FILE,2EXCLUDE_FILE", + "1EXCLUDE_FILE", + "--exclude", + "2EXCLUDE_FILE", "--exclude-regex", "EXCLUDE_REGEX", // Frontend-specific args diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/cpp/features17/Cpp17FeaturesTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/cpp/features17/Cpp17FeaturesTests.scala new file mode 100644 index 000000000000..a9d076beb9eb --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/cpp/features17/Cpp17FeaturesTests.scala @@ -0,0 +1,510 @@ +package io.joern.c2cpg.cpp.features17 + +import io.joern.c2cpg.astcreation.Defines +import io.joern.c2cpg.parser.FileDefaults +import io.joern.c2cpg.testfixtures.AstC2CpgSuite +import io.shiftleft.codepropertygraph.generated.ControlStructureTypes +import io.shiftleft.semanticcpg.language.* +import org.apache.commons.lang3.StringUtils + +class Cpp17FeaturesTests extends AstC2CpgSuite(fileSuffix = FileDefaults.CppExt) { + + "C++17 feature support" should { + + "handle template argument deduction for class templates" in { + val cpg = code(""" + |template + |struct MyContainer { + | T val; + | MyContainer() : val{} {} + | MyContainer(T val) : val{val} {} + | // ... + |}; + |MyContainer c1 {1}; // OK MyContainer + |MyContainer c2; // OK MyContainer + |""".stripMargin) + val List(c1, c2) = cpg.local.l + c1.name shouldBe "c1" + c1.typeFullName shouldBe "MyContainer" + c2.name shouldBe "c2" + c2.typeFullName shouldBe "MyContainer" + // We are unable to express this template argument deduction in the current schema + cpg.typeDecl.member.nameExact("val").typeFullName.l shouldBe List("T") + } + + "handle declaring non-type template parameters with auto" in { + val cpg = code(""" + |template + |struct integer_sequence { + | using value_type = T; + | static constexpr std::size_t size() noexcept { return sizeof...(Ints); } + |}; + | + |template + |struct my_integer_sequence { + | // Implementation here ... + |}; + | + |// Explicitly pass type `int` as template argument. + |auto seq = integer_sequence(); + |// Type is deduced to be `int`. + |auto seq2 = my_integer_sequence<0, 1, 2>(); + |""".stripMargin) + val List(seq, seq2) = cpg.local.l + seq.name shouldBe "seq" + // CDT is unable to deduce the type of the template argument + seq.typeFullName shouldBe "integer_sequence" + seq2.name shouldBe "seq2" + seq2.typeFullName shouldBe "my_integer_sequence" + } + + "handle folding expressions (binary)" in { + val cpg = code(""" + |template + |bool logicalAnd(Args... args) { + | return (true && ... && args); + |} + |bool b = true; + |bool& b2 = b; + |logicalAnd(b, b2, true); // == true + |""".stripMargin) + val List(argsParam) = cpg.method.nameExact("logicalAnd").parameter.l + argsParam.name shouldBe "args" + argsParam.typeFullName shouldBe "Args" + argsParam.isVariadic shouldBe true + val List(retExpr) = cpg.method.nameExact("logicalAnd").ast.isReturn.astChildren.isCall.l + retExpr.name shouldBe ".fold" + retExpr.typeFullName shouldBe "bool" + retExpr.code shouldBe "(true && ... && args)" + retExpr.argument.code.l shouldBe List("true", "args") + } + + "handle folding expressions (unary)" in { + val cpg = code(""" + |template + |auto sum(Args... args) { + | return (... + args); + |} + |sum(1.0, 2.0f, 3); // == 6.0 + |""".stripMargin) + val List(argsParam) = cpg.method.nameExact("sum").parameter.l + argsParam.name shouldBe "args" + argsParam.typeFullName shouldBe "Args" + argsParam.isVariadic shouldBe true + val List(retExpr) = cpg.method.nameExact("sum").ast.isReturn.astChildren.isCall.l + retExpr.name shouldBe ".fold" + retExpr.typeFullName shouldBe "Args" + retExpr.code shouldBe "(... + args)" + retExpr.argument.code.l shouldBe List("args") + } + + "handle new rules for auto deduction from braced-init-list" in { + val cpg = code(""" + |auto x1 = {1, 2, 3}; // x1 is std::initializer_list + |auto x2 {3}; // x2 is int + |auto x3 {3.0}; // x3 is double + |""".stripMargin) + val List(x1, x2, x3) = cpg.local.l + x1.name shouldBe "x1" + x1.typeFullName shouldBe Defines.Any + x2.name shouldBe "x2" + x2.typeFullName shouldBe "int" + x3.name shouldBe "x3" + x3.typeFullName shouldBe "double" + + pendingUntilFixed { + // TODO: can not be determined without type information from includes + x1.typeFullName shouldBe "std::initializer_list" + } + } + + "handle constexpr lambda" in { + val cpg = code(""" + |auto identity = [](int n) constexpr { return n; }; + |constexpr auto add = [](int x, int y) { + | auto L = [=] { return x; }; + | auto R = [=] { return y; }; + | return [=] { return L() + R(); }; + |}; + |constexpr int addOne(int n) { + | return [n] { return n + 1; }(); + |} + |""".stripMargin) + cpg.method.nameNot("").fullName.sorted shouldBe List( + // TODO: fix return types of nested lambdas + "Test0.cpp:.0:ANY(int)", + "Test0.cpp:.1.2:ANY()", + "Test0.cpp:.1.3:ANY()", + "Test0.cpp:.1.4:ANY()", + "Test0.cpp:.1:ANY(int,int)", + "Test0.cpp:.addOne.5:ANY()", + "addOne:int(int)" + ) + } + + "handle lambda capture `this` by value" in { + val cpg = code(""" + |struct MyObj { + | int value {123}; + | auto getValueCopy() { + | return [*this] { return value; }; + | } + | auto getValueRef() { + | return [this] { return value; }; + | } + |}; + |MyObj mo; + |auto valueCopy = mo.getValueCopy(); + |auto valueRef = mo.getValueRef(); + |mo.value = 321; + |valueCopy(); // 123 + |valueRef(); // 321 + |""".stripMargin) + // TODO: we can not express these lambda types in the current schema + // We would need to add a new type for lambdas that capture `this` by value copy/ref. + cpg.method.nameExact("getValueCopy").methodReturn.typeFullName.l shouldBe List(Defines.Function) + cpg.method.nameExact("getValueRef").methodReturn.typeFullName.l shouldBe List(Defines.Function) + } + + "handle inline variables" in { + val cpg = code(""" + |// Disassembly example using compiler explorer. + |struct S1 { int x; }; + |inline S1 x1 = S{321}; // mov esi, dword ptr [x1] + | // x1: .long 321 + | + |S1 x2 = S1{123}; // mov eax, dword ptr [.L_ZZ4mainE2x2] + | // mov dword ptr [rbp - 8], eax + | // .L_ZZ4mainE2x2: .long 123 + | + |struct S2 { + | S2() : id{count++} {} + | ~S2() { count--; } + | int id; + | static inline int count{0}; // declare and initialize count to 0 within the class + |}; + |""".stripMargin) + cpg.local.map(l => (l.name, l.typeFullName)).toMap shouldBe Map("x1" -> "S1", "x2" -> "S1") + cpg.typeDecl.member.nameExact("count").typeFullName.l shouldBe List("int") + } + + "handle nested namespaces" in { + val cpg = code(""" + |namespace A1 { // old + | namespace B1 { + | namespace C1 { + | int i; + | } + | } + |} + | + |namespace A2::B2::C2 { // new + | int i; + |} + |""".stripMargin) + cpg.namespaceBlock.nameNot("").name.sorted shouldBe List("A1", "A2", "B1", "B2", "C1", "C2") + cpg.namespaceBlock.nameNot("").fullName.sorted shouldBe List( + "A1", + "A1.B1", + "A1.B1.C1", + "A2", + "A2.B2", + "A2.B2.C2" + ) + } + + "handle structured bindings" in { + val cpg = code(""" + |template + |struct pair { + | T1 x; + | T2 y; + |}; + | + |using Coordinate = pair; + |Coordinate origin() { + | return Coordinate{0, 0}; + |} + | + |void foo() { + | const auto [ x, y ] = origin(); + | x; // == 0 + | y; // == 0 + | + | std::unordered_map mapping; + | // fill the map ... + | + | // Destructure by reference. + | for (const auto& [key, value] : mapping) { + | // Do something with key and value + | } + |} + |""".stripMargin) + cpg.call.code.l should contain theSameElementsAs List( + "Coordinate{0, 0}", + "{0, 0}", + "anonymous_tmp_0 = = origin()", + "origin()", + "x = anonymous_tmp_0.x", + "y = anonymous_tmp_0.y", + "anonymous_tmp_0.x", + "anonymous_tmp_0.y", + "anonymous_tmp_1 = mapping", + "key = anonymous_tmp_1.key", + "value = anonymous_tmp_1.value", + "anonymous_tmp_1.key", + "anonymous_tmp_1.value" + ) + cpg.local.map(l => (l.name, l.typeFullName)).toMap shouldBe Map( + "x" -> "int", + "y" -> "int", + "anonymous_tmp_0" -> "pair", + "mapping" -> "std.unordered_map", + // fails to resolve the type of the structured bindings without C++ header files + "anonymous_tmp_1" -> "ANY", + "key" -> "ANY", + "value" -> "ANY" + ) + cpg.controlStructure + .controlStructureTypeExact(ControlStructureTypes.FOR) + .astChildren + .isLocal + .map(l => (l.name, l.typeFullName)) + .toMap shouldBe Map( + // fails to resolve the type of the structured bindings without C++ header files + "anonymous_tmp_1" -> "ANY", + "key" -> "ANY", + "value" -> "ANY" + ) + pendingUntilFixed { + cpg.local.map(l => (l.name, l.typeFullName)).toMap shouldBe Map( + "x" -> "int", + "y" -> "int", + "anonymous_tmp_0" -> "pair", + "mapping" -> "std.unordered_map", + "anonymous_tmp_1" -> "std.unordered_map", + "key" -> "int", + "value" -> "int" + ) + cpg.controlStructure + .controlStructureTypeExact(ControlStructureTypes.FOR) + .astChildren + .astChildren + .isLocal + .map(l => (l.name, l.typeFullName)) + .toMap shouldBe Map( + // fails to resolve the type of the structured bindings without C++ header files + "anonymous_tmp_1" -> "std.unordered_map", + "key" -> "int", + "value" -> "int" + ) + } + } + + "handle selection statements with initializer" in { + val cpg = code(""" + |void foo() { + | if (std::lock_guard lk(mx); v.empty()) { + | v.push_back(val); + | } + | + | switch (Foo gadget(args); auto s = gadget.status()) { + | case OK: gadget.zip(); break; + | case Bad: throw BadFoo(s.message()); + | } + |} + |""".stripMargin) + cpg.method + .nameExact("foo") + .block + .astChildren + .isExpression + .sortBy(_.argumentIndex) + .code + .map(StringUtils.normalizeSpace) + .l shouldBe List( + "std::lock_guard lk(mx)", + "if (std::lock_guard lk(mx); v.empty()) { v.push_back(val); }", + "gadget(args)", + "s = gadget.status()", + "switch (Foo gadget(args); auto s = gadget.status()) { case OK: gadget.zip(); break; case Bad: throw BadFoo(s.message()); }" + ) + } + + "handle constexpr if" in { + val cpg = code(""" + |template + |constexpr bool isIntegral() { + | if constexpr (std::is_integral::value) { + | return true; + | } else { + | return false; + | } + |} + |static_assert(isIntegral() == true); + |static_assert(isIntegral() == true); + |static_assert(isIntegral() == false); + |struct S {}; + |static_assert(isIntegral() == false); + |""".stripMargin) + cpg.method.nameExact("isIntegral").controlStructure.code.map(StringUtils.normalizeSpace).l shouldBe List( + "if constexpr (std::is_integral::value) { return true; } else { return false; }", + "else" + ) + } + + "handle UTF-8 character literals" in { + val cpg = code(""" + |void foo() { + | char x = u8'x'; + |} + |""".stripMargin) + pendingUntilFixed { + // TODO: not supported by the CDT parser at the moment + cpg.assignment.code.l shouldBe List("char x = u8'x'") + cpg.assignment.argument(2).isLiteral.code.l shouldBe List("u8'x'") + cpg.local.nameExact("x").typeFullName.l shouldBe List("char") + cpg.identifier.nameExact("x").typeFullName.l shouldBe List("char") + } + } + + "handle direct list initialization of enums" in { + val cpg = code(""" + |enum byte : unsigned char {}; + |byte b {0}; + |byte d = byte{1}; + |""".stripMargin) + cpg.local.nameExact("b").typeFullName.l shouldBe List("byte") + cpg.identifier.nameExact("b").typeFullName.l shouldBe List("byte") + cpg.local.nameExact("d").typeFullName.l shouldBe List("byte") + cpg.identifier.nameExact("d").typeFullName.l shouldBe List("byte") + } + + "handle fallthrough, nodiscard, maybe_unused attributes" in { + val cpg = code(""" + |void foo() { + | switch (n) { + | case 1: + | // ... + | [[fallthrough]]; + | case 2: + | // ... + | break; + | case 3: + | // ... + | [[fallthrough]]; + | default: + | // ... + | } + |} + | + |[[nodiscard]] bool do_something() { + | return is_success; // true for success, false for failure + |} + |struct [[nodiscard]] error_info { + | // ... + |}; + | + |void my_callback(std::string msg, [[maybe_unused]] bool error) { + | // Don't care if `msg` is an error message, just log it. + | log(msg); + |} + |""".stripMargin) + cpg.method + .nameExact("foo") + .controlStructure + .controlStructureTypeExact(ControlStructureTypes.SWITCH) + .astChildren + .isBlock + ._jumpTargetViaAstOut + .code + .l shouldBe List("case 1:", "case 2:", "case 3:", "default:") + cpg.method.nameExact("do_something").size shouldBe 1 + cpg.typeDecl.nameExact("error_info").size shouldBe 1 + cpg.method.nameExact("my_callback").parameter.name.l shouldBe List("msg", "error") + } + + "handle _has_include" in { + val cpg = code(""" + |#ifdef __has_include + |# if __has_include() + |# include + |# define have_optional 1 + |# elif __has_include() + |# include + |# define have_optional 1 + |# define experimental_optional + |# else + |# define have_optional 0 + |# endif + |#endif + | + |#ifdef __has_include + |# if __has_include() + |# include + |# elif __has_include() + |# include + |# else + |# error No suitable headers found. + |# endif + |#endif + |""".stripMargin) + .moreCode( + """ + |int x = 1; + |""".stripMargin, + "x.h" + ) + .moreCode( + """ + |int y = 1; + |""".stripMargin, + "y.h" + ) + cpg.imports.code.l shouldBe List( + "# include ", + "# include ", + "# include ", + "# include " + ) + cpg.local.name.l shouldBe List("x", "y") + } + + "handle class template argument deduction" in { + val cpg = code(""" + |template + |struct container { + | container(T t) {} + | template + | container(Iter beg, Iter end); + |}; + | + |template + |container(Iter b, Iter e) -> container::value_type>; + | + |void foo() { + | std::vector v{ 1, 2, 3 }; // deduces std::vector + | std::mutex mtx; + | auto lck = std::lock_guard{ mtx }; // deduces to std::lock_guard + | auto p = new std::pair{ 1.0, 2.0 }; // deduces to std::pair* + | + | container a{ 7 }; // OK: deduces container + | std::vector v{ 1.0, 2.0, 3.0 }; + | auto b = container{ v.begin(), v.end() }; // OK: deduces container + |} + |""".stripMargin) + cpg.local.nameExact("a").typeFullName.l shouldBe List("container") + cpg.local.nameExact("v").typeFullName.l shouldBe List( + "std.vector", // generic types are not deduced + "std.vector" + ) + pendingUntilFixed { + // CDT deduces them to ProblemType as there are no includes for std:: + cpg.local.nameExact("p").typeFullName.l shouldBe List("std.pair*") + cpg.local.nameExact("lck").typeFullName.l shouldBe List("std.lock_guard") + cpg.local.nameExact("b").typeFullName.l shouldBe List("container") + } + } + + } +} diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/cpp/features20/Cpp20FeaturesTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/cpp/features20/Cpp20FeaturesTests.scala new file mode 100644 index 000000000000..40e4c41be913 --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/cpp/features20/Cpp20FeaturesTests.scala @@ -0,0 +1,513 @@ +package io.joern.c2cpg.cpp.features20 + +import io.joern.c2cpg.astcreation.Defines +import io.joern.c2cpg.parser.FileDefaults +import io.joern.c2cpg.testfixtures.AstC2CpgSuite +import io.shiftleft.codepropertygraph.generated.ControlStructureTypes +import io.shiftleft.semanticcpg.language.* + +class Cpp20FeaturesTests extends AstC2CpgSuite(fileSuffix = FileDefaults.CppExt) { + + "C++20 feature support" should { + + "handle coroutines" in { + val cpg = code(""" + |generator range(int start, int end) { + | while (start < end) { + | co_yield start; + | start++; + | } + |} + | + |task echo(socket s) { + | for (;;) { + | auto data = co_await s.async_read(); + | co_await async_write(s, data); + | } + |} + |""".stripMargin) + cpg.method + .nameExact("range") + .controlStructure + .astChildren + .isBlock + .astChildren + .isCall + .nameExact(".yield") + .size shouldBe 1 + + pendingUntilFixed { + // `auto data = co_await s.async_read();` can not be parsed yet. + // Hence, this co_await call is missing. + cpg.method + .nameExact("echo") + .controlStructure + .astChildren + .isBlock + .astChildren + .isCall + .nameExact(".await") + .size shouldBe 2 + } + } + + "handle concepts" in { + val cpg = code(""" + |// `T` is not limited by any constraints. + |template + |concept always_satisfied = true; + |// Limit `T` to integrals. + |template + |concept integral = std::is_integral_v; + |// Limit `T` to both the `integral` constraint and signedness. + |template + |concept signed_integral = integral && std::is_signed_v; + |// Limit `T` to both the `integral` constraint and the negation of the `signed_integral` constraint. + |template + |concept unsigned_integral = integral && !signed_integral; + | + |// Forms for function parameters: + |// `T` is a constrained type template parameter. + |template + |void f1(T v); + | + |// `T` is a constrained type template parameter. + |template + | requires my_concept + |void f2(T v); + | + |// `T` is a constrained type template parameter. + |template + |void f3(T v) requires my_concept; + | + |// `v` is a constrained deduced parameter. + |void f4(my_concept auto v); + | + |// `v` is a constrained non-type template parameter. + |template + |void f5(); + | + |void foo() { + | // Forms for lambdas: + | // `T` is a constrained type template parameter. + | auto l1 = [] (T v) { + | // ... + | }; + | // `T` is a constrained type template parameter. + | auto l2 = [] requires my_concept (T v) { + | // ... + | }; + | // `T` is a constrained type template parameter. + | auto l3 = [] (T v) requires my_concept { + | // ... + | }; + | // `v` is a constrained deduced parameter. + | auto l4 = [](my_concept auto v) { + | // ... + | }; + | // `v` is a constrained non-type template parameter. + | auto l5 = [] () { + | // ... + | }; + |} + | + |template + | requires my_concept // `requires` clause. + |void f6(T); + | + |template + |concept callable = requires (T f) { f(); }; // `requires` expression. + | + |template + | requires requires (T x) { x + x; } // `requires` clause and expression on same line. + |T add(T a, T b) { + | return a + b; + |} + | + |template + |concept C1 = requires(T x) { + | {*x} -> std::convertible_to; // the type of the expression `*x` is convertible to `T::inner` + | {x + 1} -> std::same_as; // the expression `x + 1` satisfies `std::same_as` + | {x * 1} -> std::convertible_to; // the type of the expression `x * 1` is convertible to `T` + |}; + | + |template + |concept C2 = requires(T x) { + | requires std::same_as; + |}; + |""".stripMargin) + // we can't express concepts withing the CPG but parsing constructs using them should not be hindered + // sadly, at some places (e.g., for parameters) it fails parsing + cpg.method.nameNot("").fullName.sorted.l shouldBe List( + "add:T(T,T)", + "f1:void(T)", + "f2:void(T)", + "f3:void(T)", + "f6:void(T)", + "foo:void()", + "requires:requires(T)" + ) + pendingUntilFixed { + cpg.method.nameNot("").fullName.sorted.l shouldBe List( + "add:T(T,T)", + "f1:void(T)", + "f2:void(T)", + "f3:void(T)", + "f4:void(ANY)", + "f5:void()", + "f6:void(T)", + // remaining lambdas ... + "foo:void()" + ) + } + } + + "handle three-way comparison" in { + val cpg = code(""" + |bool foo() { + | bool x = 1 <=> 2; + | return x; + |} + |""".stripMargin) + // three-way comparison operator can not be parsed at the moment + pendingUntilFixed { + cpg.assignment.code.l shouldBe List("bool x = 1 <=> 2") + // TODO: test for children AST elements + } + } + + "handle designated initializers" in { + val cpg = code(""" + |struct A { + | int x; + | int y; + | int z = 123; + |}; + | + |void foo() { + | A a {.x = 1, .z = 2}; // a.x == 1, a.y == 0, a.z == 2 + |} + |""".stripMargin) + cpg.assignment.code.sorted.l shouldBe List("a.x = 1", "a.z = 2", "z = 123") + pendingUntilFixed { + // CDT failed to assign the type A to a here so we are not + // able to identify all struct fields. Hence, no a.y = 0 assignment yet. + cpg.assignment.code.sorted.l shouldBe List("a.x = 1", "a.y = 0", "a.z = 2", "z = 123") + } + } + + "handle template syntax for lambdas" in { + val cpg = code(""" + |void foo() { + | auto f = [](std::vector v) { + | // ... + | }; + |} + |""".stripMargin) + cpg.method.nameNot("").fullName.sorted.l shouldBe List("foo:void()") + pendingUntilFixed { + // [](std::vector v) { ... } can not be parsed by CDT at the moment + cpg.method.nameNot("").fullName.sorted.l shouldBe List("foo:void()", "0") + } + } + + "handle range-based for loop with initializer" in { + val cpg = code(""" + |void foo() { + | for (auto v = list; auto& e : v) { + | std::cout << e; + | } + |} + |""".stripMargin) + pendingUntilFixed { + // range-based for loop with initializer can not be parsed at all by CDT at the moment + val List(v, e) = cpg.method.nameExact("foo").controlStructure.astChildren.isLocal.l + v.name shouldBe "v" + e.name shouldBe "e" + val List(vectorInitCall) = + cpg.method.nameExact("foo").controlStructure.astChildren.order(2).isCall.argument.isCall.l + vectorInitCall.argumentIndex shouldBe 2 + vectorInitCall.name shouldBe Defines.OperatorConstructorInitializer + vectorInitCall.argument.code.l shouldBe List("1", "2", "3") + } + } + + "handle likely and unlikely attributes" in { + val cpg = code(""" + |void foo() { + | switch (n) { + | case 1: + | case1(); + | break; + | [[likely]] case 2: // n == 2 is considered to be arbitrarily more + | case2() // likely than any other value of n + | break; + | } + | + | int random = get_random_number_between_x_and_y(0, 3); + | if (random > 0) [[likely]] { + | // body of if statement + | likelyIf(); + | } + | + | while (unlikely_truthy_condition) [[unlikely]] { + | // body of while statement + | unlikelyWhile() + | } + |} + |""".stripMargin) + val cases = + cpg.method + .nameExact("foo") + .controlStructure + .controlStructureTypeExact(ControlStructureTypes.SWITCH) + .astChildren + .isBlock + ._jumpTargetViaAstOut + cases.code.l shouldBe List("case 1:", "[[likely]] case 2:") + cpg.method + .nameExact("foo") + .controlStructure + .controlStructureTypeExact(ControlStructureTypes.SWITCH) + .ast + .isCall + .code + .l shouldBe List("case1()", "case2()") + + cpg.method + .nameExact("foo") + .controlStructure + .controlStructureTypeExact(ControlStructureTypes.IF) + .ast + .isCall + .code + .l shouldBe List("random > 0", "likelyIf()") + + cpg.method + .nameExact("foo") + .controlStructure + .controlStructureTypeExact(ControlStructureTypes.WHILE) + .ast + .isCall + .code + .l shouldBe List("unlikelyWhile()") + } + + "handle deprecate implicit capture of this" in { + val cpg = code(""" + |struct int_value { + | int n = 0; + | auto getter_fn() { + | return [=, *this]() { return n; }; + | } + |}; + |""".stripMargin) + // TODO: we can not express these lambda types in the current schema + // We would need to add a new type for lambdas that capture `this` + cpg.method.nameExact("getter_fn").methodReturn.typeFullName.l shouldBe List(Defines.Function) + } + + "handle class types in non-type template parameters" in { + val cpg = code(""" + |struct foo { + | foo() = default; + | constexpr foo(int) {} + |}; + | + |template + |auto get_foo() { + | return f; + |} + | + |void main() { + | get_foo(); // uses implicit constructor + | get_foo(); + |} + |""".stripMargin) + cpg.typeDecl.nameExact("foo").size shouldBe 1 + cpg.method.nameExact("get_foo").size shouldBe 1 + cpg.method.nameExact("main").ast.isCall.typeFullName.l shouldBe List("ANY", "foo") + cpg.method.nameExact("main").ast.isCall.methodFullName.l shouldBe List( + // we can not resolve the implicit constructor call case: + ".get_foo:(0)", + "get_foo:foo()" + ) + pendingUntilFixed { + cpg.method.nameExact("main").ast.isCall.typeFullName.l shouldBe List("foo", "foo") + cpg.method.nameExact("main").ast.isCall.methodFullName.l shouldBe List("get_foo:foo()", "get_foo:foo()") + } + } + + "handle constexpr virtual functions" in { + val cpg = code(""" + |struct X1 { + | virtual int f() const = 0; + |}; + | + |struct X2: public X1 { + | constexpr virtual int f() const { return 2; } + |}; + | + |struct X3: public X2 { + | virtual int f() const { return 3; } + |}; + | + |struct X4: public X3 { + | constexpr virtual int f() const { return 4; } + |}; + | + |void foo() { + | constexpr X4 x4; + | x4.f(); // == 4 + |} + |""".stripMargin) + cpg.method.nameNot("").fullName.sorted.l shouldBe List( + "X1.f:int()", + "X2.f:int()", + "X3.f:int()", + "X4.f:int()", + "foo:void()" + ) + cpg.method.nameExact("foo").local.typeFullName.l shouldBe List("X4") + } + + "handle explicit(bool)" in { + val cpg = code(""" + |struct foo { + | // Specify non-integral types (strings, floats, etc.) require explicit construction. + | template + | explicit(!std::is_integral_v) foo(T) {} + |}; + | + |void foo() { + | foo a = 123; + | foo c {"123"}; + |} + |""".stripMargin) + cpg.method.nameExact("foo").local.typeFullName.distinct.l shouldBe List("foo") + pendingUntilFixed { + // CDT can not parse explicit(bool) yet + cpg.method.nameNot("").fullName.sorted.l shouldBe List("A.foo:T(T)", "foo:void()") + } + } + + "handle immediate functions" in { + val cpg = code(""" + |consteval int sqr(int n) { + | return n * n; + |} + | + |void foo() { + | constexpr int r = sqr(100); + |} + |""".stripMargin) + cpg.method.nameNot("").fullName.sorted.l shouldBe List("foo:void()", "sqr:int(int)") + val List(rLocal) = cpg.method.nameExact("foo").local.l + rLocal.typeFullName shouldBe "int" + rLocal.code shouldBe "constexpr int r" + } + + "handle using enum" in { + val cpg = code(""" + |enum class rgba_color_channel { red, green, blue, alpha }; + | + |std::string_view to_string(rgba_color_channel my_channel) { + | switch (my_channel) { + | using enum rgba_color_channel; + | case red: return "red"; + | case green: return "green"; + | case blue: return "blue"; + | case alpha: return "alpha"; + | } + |} + |""".stripMargin) + val List(switchBlock) = cpg.method + .nameExact("to_string") + .controlStructure + .controlStructureTypeExact(ControlStructureTypes.SWITCH) + .astChildren + .isBlock + .l + switchBlock._jumpTargetViaAstOut.code.l shouldBe List("case red:", "case green:", "case blue:", "case alpha:") + pendingUntilFixed { + // using clause is not parsed by CDT at all yet + switchBlock._jumpTargetViaAstOut.astChildren.isCall.code.l shouldBe List( + "rgba_color_channel.red", + "rgba_color_channel.green", + "rgba_color_channel.blue", + "rgba_color_channel.alpha" + ) + } + } + + "handle lambda capture of parameter pack" in { + val cpg = code(""" + |template + |auto f1(Args&&... args){ + | // BY VALUE: + | return [...args = std::forward(args)] {}; + |} + | + |template + |auto f2(Args&&... args){ + | // BY REFERENCE: + | return [&...args = std::forward(args)] {}; + |} + |""".stripMargin) + cpg.method.nameNot("").fullName.sorted.l shouldBe List("f1:ANY(Args&&)", "f2:ANY(Args&&)") + cpg.method.nameNot("").signature.sorted.l shouldBe List("ANY(Args&&)", "ANY(Args&&)") + pendingUntilFixed { + // the actual return value (i.e., the lambda defined at the return) can not be parsed by CDT + cpg.method.nameExact("f1").ast.isReturn.astChildren.isMethodRef.code.l shouldBe List( + "[...args = std::forward(args)] {};" + ) + cpg.method.nameExact("f2").ast.isReturn.astChildren.isMethodRef.code.l shouldBe List( + "[&...args = std::forward(args)] {};" + ) + } + } + + "handle char8_t" in { + val cpg = code(""" + |char8_t utf8_str[] = u8"\u0123"; + |""".stripMargin) + val List(assignmentCall) = cpg.call.l + assignmentCall.code shouldBe """utf8_str[] = u8"\u0123"""" + val List(utf8_str) = assignmentCall.argument.isIdentifier.l + utf8_str.name shouldBe "utf8_str" + utf8_str.typeFullName shouldBe "char8_t[6]" + val List(u8) = assignmentCall.argument.isLiteral.l + u8.code shouldBe """u8"\u0123"""" + u8.typeFullName shouldBe "char[6]" + } + + "handle constinit" in { + val cpg = code(""" + |constexpr const char* f(bool p) { return p ? "constant initializer" : g(); } + | + |void foo() { + | constinit const char *c = f(true); + |} + |""".stripMargin) + cpg.method.nameNot("").fullName.sorted.l shouldBe List("f:char*(bool)", "foo:void()") + val List(cLocal) = cpg.method.nameExact("foo").local.l + cLocal.typeFullName shouldBe "char*" + cLocal.code shouldBe "const char *c" // constinit keyword is not parsed by CDT + } + + "handle __VA_OPT__" in { + val cpg = code(""" + |#define F(...) f(0 __VA_OPT__(,) __VA_ARGS__) + |void foo() { + | F(a, b, c); // replaced by f(0, a, b, c) + | F(); // replaced by f(0) + |} + |""".stripMargin) + pendingUntilFixed { + // Impossible to test without C++ system headers for __VA_OPT__ and __VA_ARGS__ definitions + cpg.call.code.l shouldBe List("f(a, b, c)", "f()") + } + } + + } +} diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/dataflow/DataFlowTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/dataflow/DataFlowTests.scala index a050ffde48cd..40ec4942a5a9 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/dataflow/DataFlowTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/dataflow/DataFlowTests.scala @@ -2,12 +2,13 @@ package io.joern.c2cpg.dataflow import io.joern.c2cpg.testfixtures.DataFlowCodeToCpgSuite import io.joern.dataflowengineoss.language.* -import io.joern.dataflowengineoss.queryengine.{EngineConfig, EngineContext} +import io.joern.dataflowengineoss.queryengine.EngineConfig +import io.joern.dataflowengineoss.queryengine.EngineContext import io.shiftleft.codepropertygraph.generated.EdgeTypes -import io.shiftleft.codepropertygraph.generated.nodes.{CfgNode, Identifier, Literal} +import io.shiftleft.codepropertygraph.generated.nodes.CfgNode +import io.shiftleft.codepropertygraph.generated.nodes.Identifier +import io.shiftleft.codepropertygraph.generated.nodes.Literal import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help.Table.AvailableWidthProvider -import overflowdb.traversal.toNodeTraversal class DataFlowTests extends DataFlowCodeToCpgSuite { @@ -1082,7 +1083,7 @@ class DataFlowTests extends DataFlowCodeToCpgSuite { cpg .call("bar") .outE(EdgeTypes.REACHING_DEF) - .count(_.inNode() == cpg.ret.head) shouldBe 1 + .count(_.dst == cpg.ret.head) shouldBe 1 } } @@ -1989,4 +1990,200 @@ class DataFlowTestsWithCallDepth extends DataFlowCodeToCpgSuite { ) } } + + "DataFlowTest73" should { + val cpg = code(""" + |int main(void) { + | int x = 5; + | call1(x%=2); + | call2(x); + |} + |""".stripMargin) + + "the literal in x%=2 should taint the outer expression" in { + val source = cpg.literal("2") + val sink = cpg.call("call1") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x%=2", 4), ("call1(x%=2)", 4))) + } + + "the literal in x%=2 should taint the next occurrence of x" in { + val source = cpg.literal("2") + val sink = cpg.call("call2") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x%=2", 4), ("call2(x)", 5))) + } + + } + + "DataFlowTest74" should { + val cpg = code(""" + |int main(void) { + | int x = 5; + | call1(x^=2); + | call2(x); + |} + |""".stripMargin) + + "the literal in x^=2 should taint the outer expression" in { + val source = cpg.literal("2") + val sink = cpg.call("call1") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x^=2", 4), ("call1(x^=2)", 4))) + } + + "the literal in x^=2 should taint the next occurrence of x" in { + val source = cpg.literal("2") + val sink = cpg.call("call2") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x^=2", 4), ("call2(x)", 5))) + } + } + + "DataFlowTest75" should { + val cpg = code(""" + |int main(void) { + | int x = 5; + | call1(x|=2); + | call2(x); + |} + |""".stripMargin) + + "the literal in x|=2 should taint the outer expression" in { + val source = cpg.literal("2") + val sink = cpg.call("call1") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x|=2", 4), ("call1(x|=2)", 4))) + } + + "the literal in x|=2 should taint the next occurrence of x" in { + val source = cpg.literal("2") + val sink = cpg.call("call2") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x|=2", 4), ("call2(x)", 5))) + } + } + + "DataFlowTest76" should { + val cpg = code(""" + |int main(void) { + | int x = 5; + | call1(x&=2); + | call2(x); + |} + |""".stripMargin) + + "the literal in x&=2 should taint the outer expression" in { + val source = cpg.literal("2") + val sink = cpg.call("call1") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x&=2", 4), ("call1(x&=2)", 4))) + } + + "the literal in x&=2 should taint the next occurrence of x" in { + val source = cpg.literal("2") + val sink = cpg.call("call2") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x&=2", 4), ("call2(x)", 5))) + } + } + + "DataFlowTest77" should { + val cpg = code(""" + |int main(void) { + | int x = 5; + | call1(x<<=2); + | call2(x); + |} + |""".stripMargin) + + "the literal in x<<=2 should taint the outer expression" in { + val source = cpg.literal("2") + val sink = cpg.call("call1") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x<<=2", 4), ("call1(x<<=2)", 4))) + } + + "the literal in x<<=2 should taint the next occurrence of x" in { + val source = cpg.literal("2") + val sink = cpg.call("call2") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x<<=2", 4), ("call2(x)", 5))) + } + } + + "DataFlowTest78" should { + val cpg = code(""" + |int main(void) { + | int x = 5; + | call1(x>>=2); + | call2(x); + |} + |""".stripMargin) + + "the literal in x>>=2 should taint the outer expression" in { + val source = cpg.literal("2") + val sink = cpg.call("call1") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x>>=2", 4), ("call1(x>>=2)", 4))) + } + + "the literal in x>>=2 should taint the next occurrence of x" in { + val source = cpg.literal("2") + val sink = cpg.call("call2") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List(List(("x>>=2", 4), ("call2(x)", 5))) + } + } + + "DataFlowTest79" should { + val cpg = code(""" + |int main(void) { + | int x = 5; + | int y = 2; + | int z = x % y; + | call1(z); + |} + |""".stripMargin) + + "the first argument in a % operation should not taint its second argument" in { + val source = cpg.literal("5") + val sink = cpg.identifier("y").lineNumber(5) + sink.reachableByFlows(source) shouldBe empty + } + + "the second argument in a % operation should not taint its first argument" in { + val source = cpg.literal("2") + val sink = cpg.identifier("x").lineNumber(5) + sink.reachableByFlows(source) shouldBe empty + } + + "the arguments in a % operation should taint its return value" in { + val source = cpg.literal + val sink = cpg.call("call1").argument + sink.reachableByFlows(source).map(flowToResultPairs).toSetMutable shouldBe Set( + List(("x = 5", 3), ("x % y", 5), ("z = x % y", 5), ("call1(z)", 6)), + List(("y = 2", 4), ("x % y", 5), ("z = x % y", 5), ("call1(z)", 6)) + ) + } + } + + "DataFlowTest80" should { + val cpg = code(""" + |int main(void) { + | int x = 10; + | int y = 20; + | int z[] = {x, y, 30}; + | call1(z); + |} + |""".stripMargin) + + "elements of an arrayInitializer should not taint each other" in { + val x = cpg.identifier("x").lineNumber(5).l + val y = cpg.identifier("y").lineNumber(5).l + val z = cpg.literal("30").l + x.reachableByFlows(y ++ z) shouldBe empty + y.reachableByFlows(x ++ z) shouldBe empty + z.reachableByFlows(x ++ y) shouldBe empty + } + + "elements of an arrayInitializer should taint its return value" in { + val x = cpg.literal("10") + val y = cpg.literal("20") + val z = cpg.literal("30") + cpg.call("call1").argument.reachableByFlows(x ++ y ++ z).map(flowToResultPairs).toSetMutable shouldBe Set( + List(("x = 10", 3), ("{x, y, 30}", 5), ("z[] = {x, y, 30}", 5), ("call1(z)", 6)), + List(("y = 20", 4), ("{x, y, 30}", 5), ("z[] = {x, y, 30}", 5), ("call1(z)", 6)), + List(("{x, y, 30}", 5), ("z[] = {x, y, 30}", 5), ("call1(z)", 6)) + ) + } + } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/dataflow/ReachingDefTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/dataflow/ReachingDefTests.scala index c1c41699ed98..c8c0771dc094 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/dataflow/ReachingDefTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/dataflow/ReachingDefTests.scala @@ -3,7 +3,7 @@ package io.joern.c2cpg.dataflow import io.joern.c2cpg.testfixtures.DataFlowCodeToCpgSuite import io.joern.dataflowengineoss.passes.reachingdef.ReachingDefFlowGraph import io.shiftleft.codepropertygraph.generated.nodes.Identifier -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ReachingDefTests extends DataFlowCodeToCpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/C2CpgHTTPServerTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/C2CpgHTTPServerTests.scala new file mode 100644 index 000000000000..b39faa919658 --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/C2CpgHTTPServerTests.scala @@ -0,0 +1,83 @@ +package io.joern.c2cpg.io + +import better.files.File +import io.joern.x2cpg.utils.server.FrontendHTTPClient +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader +import io.shiftleft.semanticcpg.language.* +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.util.Failure +import scala.util.Success +import scala.collection.parallel.CollectionConverters.RangeIsParallelizable + +class C2CpgHTTPServerTests extends AnyWordSpec with Matchers with BeforeAndAfterAll { + + private var port: Int = -1 + + private def newProjectUnderTest(index: Option[Int] = None): File = { + val dir = File.newTemporaryDirectory("c2cpgTestsHttpTest") + val file = dir / "main.c" + file.createIfNotExists(createParents = true) + val indexStr = index.map(_.toString).getOrElse("") + file.writeText(s""" + |int main$indexStr(int argc, char *argv[]) { + | print("Hello World!"); + |} + |""".stripMargin) + file.deleteOnExit() + dir.deleteOnExit() + } + + override def beforeAll(): Unit = { + // Start server + port = io.joern.c2cpg.Main.startup() + } + + override def afterAll(): Unit = { + // Stop server + io.joern.c2cpg.Main.stop() + } + + "Using c2cpg in server mode" should { + "build CPGs correctly (single test)" in { + val cpgOutFile = File.newTemporaryFile("c2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l shouldBe List("""print("Hello World!")""") + } + } + + "build CPGs correctly (multi-threaded test)" in { + (0 until 10).par.foreach { index => + val cpgOutFile = File.newTemporaryFile("c2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest(Some(index)) + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain(s"main$index") + cpg.call.code.l shouldBe List("""print("Hello World!")""") + } + } + } + } + +} diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/CodeDumperFromFileTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/CodeDumperFromFileTests.scala index e05a794841d5..fcd7e87c1502 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/CodeDumperFromFileTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/CodeDumperFromFileTests.scala @@ -3,7 +3,7 @@ package io.joern.c2cpg.io import better.files.File import io.joern.c2cpg.testfixtures.C2CpgSuite import io.shiftleft.semanticcpg.codedumper.CodeDumper -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.util.regex.Pattern diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/ExcludeTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/ExcludeTests.scala index 4d750ebda36e..16a9d1ebcf13 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/ExcludeTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/ExcludeTests.scala @@ -4,7 +4,7 @@ import better.files.File import io.joern.c2cpg.Config import io.joern.c2cpg.C2Cpg import io.joern.x2cpg.X2Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import org.scalatest.matchers.should.Matchers import org.scalatest.prop.TableDrivenPropertyChecks @@ -62,6 +62,25 @@ class ExcludeTests extends AnyWordSpec with Matchers with TableDrivenPropertyChe ) } + "Using case sensitive excludes" should { + "exclude the given files correctly" in { + if (scala.util.Properties.isWin || scala.util.Properties.isMac) { + // both are written uppercase and are ignored nevertheless because + // the file systems are case-insensitive by default + testWithArguments(Seq("Folder", "Index.c"), "", Set("a.c", "foo.bar/d.c")) + } + if (scala.util.Properties.isLinux) { + // both are written uppercase and are not ignored because + // ext3/ext4 and many other Linux filesystems are case-sensitive by default + testWithArguments( + Seq("Folder", "Index.c"), + "", + Set("a.c", "folder/b.c", "folder/c.c", "foo.bar/d.c", "index.c") + ) + } + } + } + "Using different excludes via program arguments" should { val testInput = Table( diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/FileHandlingTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/FileHandlingTests.scala new file mode 100644 index 000000000000..7caf00074c23 --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/FileHandlingTests.scala @@ -0,0 +1,75 @@ +package io.joern.c2cpg.io + +import better.files.File +import io.joern.c2cpg.parser.FileDefaults +import io.joern.c2cpg.testfixtures.CDefaultTestCpg +import io.joern.dataflowengineoss.DefaultSemantics +import io.joern.dataflowengineoss.semanticsloader.NoSemantics +import io.joern.x2cpg.testfixtures.Code2CpgFixture +import io.shiftleft.semanticcpg.language.* + +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths + +object FileHandlingTests { + private val brokenLinkedFile: String = "broken.c" + private val cyclicLinkedFile: String = "loop.c" +} + +class FileHandlingTests + extends Code2CpgFixture(() => + new CDefaultTestCpg(FileDefaults.CExt) { + override def codeFilePreProcessing(codeFile: Path): Unit = { + if (codeFile.toString.endsWith(FileHandlingTests.brokenLinkedFile)) { + File(codeFile).delete().symbolicLinkTo(File("does/not/exist.c")) + } + if (codeFile.toString.endsWith(FileHandlingTests.cyclicLinkedFile)) { + val dir = File(codeFile).delete().parent + val folderA = Paths.get(dir.toString(), "FolderA") + val folderB = Paths.get(dir.toString(), "FolderB") + val symlinkAtoB = folderA.resolve("LinkToB") + val symlinkBtoA = folderB.resolve("LinkToA") + Files.createDirectory(folderA) + Files.createDirectory(folderB) + Files.createSymbolicLink(symlinkAtoB, folderB) + Files.createSymbolicLink(symlinkBtoA, folderA) + } + } + } + .withOssDataflow(false) + .withSemantics(DefaultSemantics()) + .withPostProcessingPasses(false) + ) { + + "File handling 1" should { + val cpg = code( + """ + |int a() {} + |""".stripMargin, + "a.c" + ).moreCode("", FileHandlingTests.brokenLinkedFile) + + "not crash on broken symlinks" in { + val fileNames = cpg.file.name.l + fileNames should contain("a.c").and(not contain FileHandlingTests.brokenLinkedFile) + } + + } + + "File handling 2" should { + val cpg = code( + """ + |int a() {} + |""".stripMargin, + "a.c" + ).moreCode("", FileHandlingTests.cyclicLinkedFile) + + "not crash on cyclic symlinks" in { + val fileNames = cpg.file.name.l + fileNames should contain("a.c") + } + + } + +} diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/JSONCompilationDatabaseParserTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/JSONCompilationDatabaseParserTests.scala new file mode 100644 index 000000000000..0c786889fd63 --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/JSONCompilationDatabaseParserTests.scala @@ -0,0 +1,164 @@ +package io.joern.c2cpg.io + +import better.files.File +import io.joern.c2cpg.parser.JSONCompilationDatabaseParser +import io.joern.c2cpg.C2Cpg +import io.joern.c2cpg.Config +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.language.types.structure.FileTraversal +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import java.nio.file.Paths + +class JSONCompilationDatabaseParserTests extends AnyWordSpec with Matchers { + + private def newProjectUnderTest(): File = { + val dir = File.newTemporaryDirectory("c2cpgJSONCompilationDatabaseParserTests") + + val mainText = + """ + |int main(int argc, char *argv[]) { + | print("Hello World!"); + |} + |#ifdef SOMEDEFA + |void foo() {} + |#endif + |#ifdef SOMEDEFC + |void bar() {} + |#endif + |""".stripMargin + + val fileA = dir / "fileA.c" + fileA.createIfNotExists(createParents = true) + fileA.writeText(mainText) + fileA.deleteOnExit() + val fileB = dir / "fileB.c" + fileB.createIfNotExists(createParents = true) + fileB.writeText(mainText) + fileB.deleteOnExit() + val fileC = dir / "fileC.c" + fileC.createIfNotExists(createParents = true) + fileC.writeText(mainText) + fileC.deleteOnExit() + + val compilerCommands = dir / "compile_commands.json" + compilerCommands.createIfNotExists(createParents = true) + val content = + s""" + |[ + | { "directory": "${dir.pathAsString}", + | "arguments": ["/usr/bin/clang++", "-Irelative", "-DSOMEDEFA=With spaces, quotes and \\-es.", "-c", "-o", "fileA.o", "fileA.cc"], + | "file": "fileA.c" }, + | { "directory": ".", + | "arguments": ["/usr/bin/clang++", "-Irelative", "-DSOMEDEFB=With spaces, quotes and \\-es.", "-c", "-o", "fileB.o", "fileB.cc"], + | "file": "${fileB.pathAsString}" } + |]""".stripMargin.replace("\\", "\\\\") // escape for tests under Windows + compilerCommands.writeText(content) + compilerCommands.deleteOnExit() + + dir.deleteOnExit() + } + + private def newBrokenProjectUnderTest(): File = { + val dir = File.newTemporaryDirectory("c2cpgJSONCompilationDatabaseParserTests") + + val mainText = + """ + |int main(int argc, char *argv[]) { + | print("Hello World!"); + |} + |""".stripMargin + + val fileA = dir / "fileA.c" + fileA.createIfNotExists(createParents = true) + fileA.writeText(mainText) + fileA.deleteOnExit() + + val compilerCommands = dir / "compile_commands.json" + compilerCommands.createIfNotExists(createParents = true) + val content = + s""" + |[ + | { "directory": "${dir.pathAsString}", + | "arguments": ["/usr/bin/clang++", "-Irelative", "-DSOMEDEFA=With spaces, quotes and \\-es.", "-c", "-o", "fileA.o", "fileA.cc"], + | "file": "fileA.c" }, + | { "directory": "/does/not/exist", + | "arguments": ["/usr/bin/clang++", "-c", "-o", "fileB.o", "name.cpp"], + | "file": "name.cpp" } + |]""".stripMargin.replace("\\", "\\\\") // escape for tests under Windows + compilerCommands.writeText(content) + compilerCommands.deleteOnExit() + + dir.deleteOnExit() + } + + "Parsing a simple compile_commands.json" should { + "generate a proper list of CommandObjects" in { + val content = + """ + |[ + | { "directory": "/home/user/llvm/build", + | "arguments": ["/usr/bin/clang++", "-I/usr/include", "-I./include", "-DSOMEDEFA=With spaces, quotes and \\-es.", "-c", "-o", "file.o", "file.cc"], + | "file": "file.cc" }, + | { "directory": "/home/user/llvm/build", + | "command": "/usr/bin/clang++ -I/home/user/project/includes -DSOMEDEFB=\"With spaces, quotes and \\-es.\" -DSOMEDEFC -c -o file.o file.cc", + | "file": "file2.cc" } + |]""".stripMargin + + File.usingTemporaryFile("compile_commands.json") { commandJsonFile => + commandJsonFile.writeText(content) + + val commandObjects = JSONCompilationDatabaseParser.parse(commandJsonFile.pathAsString) + commandObjects.map(_.compiledFile()) shouldBe Set( + Paths.get("/home/user/llvm/build/file.cc").toString, + Paths.get("/home/user/llvm/build/file2.cc").toString + ) + commandObjects.flatMap(_.defines()) shouldBe Set( + ("SOMEDEFA", "With spaces, quotes and \\-es."), + ("SOMEDEFB", "\"With spaces, quotes and \\-es.\""), + ("SOMEDEFC", "") + ) + commandObjects.flatMap(_.includes()) shouldBe Set("/usr/include", "./include", "/home/user/project/includes") + } + } + } + + "Using a simple compile_commands.json" should { + "respect the files listed" in { + val cpgOutFile = File.newTemporaryFile("c2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val config = Config() + .withInputPath(input) + .withOutputPath(output) + .withCompilationDatabase((File(input) / "compile_commands.json").pathAsString) + val c2cpg = new C2Cpg() + val cpg = c2cpg.createCpg(config).get + cpg.file.nameNot(FileTraversal.UNKNOWN, "").name.sorted.l should contain theSameElementsAs List( + "fileA.c", + "fileB.c" + // fileC.c is ignored because it is not listed in the compile_commands.json + ) + cpg.method.nameNot("").name.sorted.l shouldBe List("foo", "main", "main") + } + + "handle broken file paths" in { + val cpgOutFile = File.newTemporaryFile("c2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newBrokenProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val config = Config() + .withInputPath(input) + .withOutputPath(output) + .withCompilationDatabase((File(input) / "compile_commands.json").pathAsString) + val c2cpg = new C2Cpg() + val cpg = c2cpg.createCpg(config).get + cpg.file.nameNot(FileTraversal.UNKNOWN, "").name.l shouldBe List("fileA.c") + cpg.method.nameNot("").name.l shouldBe List("main") + } + } +} diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/LogFromCCorePluginTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/LogFromCCorePluginTests.scala new file mode 100644 index 000000000000..1f90693353f6 --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/LogFromCCorePluginTests.scala @@ -0,0 +1,28 @@ +package io.joern.c2cpg.io + +import io.joern.c2cpg.testfixtures.C2CpgSuite +import io.shiftleft.semanticcpg.language.* +import org.eclipse.cdt.core.CCorePlugin +import org.eclipse.cdt.internal.core.parser.ParserException + +class LogFromCCorePluginTests extends C2CpgSuite { + + private val codeString = """ + |// A comment + |int my_func(int param1) + |{ + | int x = foo(param1); + |}""".stripMargin + + private val cpg = code(codeString) + + "logging from CCorePlugin" should { + + "not crash with an exception" in { + noException should be thrownBy CCorePlugin.log(new ParserException("Test Exception!")) + val List(func) = cpg.method.nameExact("my_func").l + func.fullName shouldBe "my_func" + } + } + +} diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotAstGeneratorTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotAstGeneratorTests.scala index 76a1bc882684..955406b66103 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotAstGeneratorTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotAstGeneratorTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.io.dotgenerator import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DotAstGeneratorTests extends C2CpgSuite { @@ -32,7 +32,10 @@ class DotAstGeneratorTests extends C2CpgSuite { inside(cpg.method.name("my_func").dotAst.l) { case List(x) => x should ( startWith("digraph \"my_func\"") and - include("""[label = <(CONTROL_STRUCTURE,IF,if (y > 42))5> ]""") and + include( + """[label = IF
if (y > 42) { return y; } else { return sqrt(y); }> ]""" + ) and + include("""[label = 42
y > 42> ]""") and endWith("}\n") ) } @@ -52,7 +55,7 @@ class DotAstGeneratorTests extends C2CpgSuite { "allow plotting sub trees of methods" in { inside(cpg.method.ast.isControlStructure.code(".*y > 42.*").dotAst.l) { case List(x, _) => - x should (include("y > 42") and include("IDENTIFIER,y") and not include "x * 2") + x should (include("y > 42") and include("IDENTIFIER, 5
y") and not include "x * 2") } } @@ -60,9 +63,9 @@ class DotAstGeneratorTests extends C2CpgSuite { inside(cpg.method.name("lemon").dotAst.l) { case List(x) => x should ( startWith("digraph \"lemon\"") and - include("""[label = <(goog,goog("\"yes\""))18> ]""") and + include("""[label = goog("\"yes\"")> ]""") and include( - """[label = <(LITERAL,"\"yes\"",goog("\"yes\""))18> ]""" + """[label = "\"yes\""
goog("\"yes\"")> ]""" ) and endWith("}\n") ) diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotCdgGeneratorTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotCdgGeneratorTests.scala index 08778bd8ec6e..437d81601b56 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotCdgGeneratorTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotCdgGeneratorTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.io.dotgenerator import io.joern.c2cpg.testfixtures.DataFlowCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DotCdgGeneratorTests extends DataFlowCodeToCpgSuite { @@ -19,9 +19,9 @@ class DotCdgGeneratorTests extends DataFlowCodeToCpgSuite { inside(cpg.method.name("foo").dotCdg.l) { case List(x) => x should ( startWith("digraph \"foo\"") and - include("""[label = <(<operator>.greaterThan,x > 8)3> ]""") and - include("""[label = <(<operator>.assignment,z = a(x))4> ]""") and - include("""[label = <(a,a(x))4> ]""") and + include("""[label = <<operator>.greaterThan, 3
x > 8> ]""") and + include("""[label = <<operator>.assignment, 4
z = a(x)> ]""") and + include("""[label = a(x)> ]""") and endWith("}\n") ) val lines = x.split("\n") @@ -46,9 +46,9 @@ class DotCdgGeneratorTests extends DataFlowCodeToCpgSuite { inside(cpg.method.name("foo").dotCdg.l) { case List(x) => x should ( startWith("digraph \"foo\"") and - include("""[label = <(<operator>.greaterThan,x > 8)3> ]""") and - include("""[label = <(<operator>.assignment,z = a(x))4> ]""") and - include("""[label = <(a,a(x))4> ]""") and + include("""[label = <<operator>.greaterThan, 3
x > 8> ]""") and + include("""[label = <<operator>.assignment, 4
z = a(x)> ]""") and + include("""[label = a(x)> ]""") and endWith("}\n") ) val lines = x.split("\n") diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotCfgGeneratorTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotCfgGeneratorTests.scala index 4cd62011ce76..8b6744e83aac 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotCfgGeneratorTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotCfgGeneratorTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.io.dotgenerator import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DotCfgGeneratorTests extends C2CpgSuite { @@ -21,7 +21,7 @@ class DotCfgGeneratorTests extends C2CpgSuite { inside(cpg.method.name("main").dotCfg.l) { case List(dotStr) => dotStr should ( startWith("digraph \"main\" {") and - include("(<operator>.assignment,i = 0)") and + include("[label = <<operator>.assignment, 3
i = 0> ]") and endWith("}\n") ) } @@ -70,15 +70,9 @@ class DotCfgGeneratorTests extends C2CpgSuite { val cpg = code(""" |int example(int a, int b, int c) { | int x = 3; - | if(a) { - | foo(); - | } - | if(b) { - | foo_2(); - | } - | if (c) { - | foo_3(); - | } + | if(a) { foo(); } + | if(b) { foo_2(); } + | if (c) { foo_3(); } |} |""".stripMargin) @@ -86,9 +80,9 @@ class DotCfgGeneratorTests extends C2CpgSuite { inside(cpg.method.name("example").dotCfg.l) { case List(dotStr) => dotStr should ( startWith("digraph \"example\" {") and - include("<(IDENTIFIER,a,if (a))4>") and - include("<(IDENTIFIER,b,if (b))7>") and - include("<(IDENTIFIER,c,if (c))10>") and + include("[label = a
if(a) { foo(); }> ]") and + include("[label = b
if(b) { foo_2(); }> ]") and + include("[label = c
if (c) { foo_3(); }> ]") and endWith("}\n") ) } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotDdgGeneratorTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotDdgGeneratorTests.scala index 846ca622d3d7..f043b7fbe7f5 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotDdgGeneratorTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/io/dotgenerator/DotDdgGeneratorTests.scala @@ -1,8 +1,8 @@ package io.joern.c2cpg.io.dotgenerator import io.joern.c2cpg.testfixtures.DataFlowCodeToCpgSuite -import io.joern.dataflowengineoss.language._ -import io.shiftleft.semanticcpg.language._ +import io.joern.dataflowengineoss.language.* +import io.shiftleft.semanticcpg.language.* class DotDdgGeneratorTests extends DataFlowCodeToCpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/macros/MacroHandlingTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/macros/MacroHandlingTests.scala index 56192a053e2f..18a1f55935d6 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/macros/MacroHandlingTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/macros/MacroHandlingTests.scala @@ -2,13 +2,13 @@ package io.joern.c2cpg.macros import io.joern.c2cpg.testfixtures.C2CpgSuite import io.joern.c2cpg.testfixtures.DataFlowCodeToCpgSuite -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.Block import io.shiftleft.codepropertygraph.generated.nodes.Call import io.shiftleft.codepropertygraph.generated.nodes.Identifier -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MacroHandlingTests extends C2CpgSuite { @@ -231,7 +231,7 @@ class MacroHandlingTests extends C2CpgSuite { """.stripMargin) "should not result in malformed CFGs when expanding a nested macro with block" in { - cpg.all.collectAll[Block].l.count(b => b.cfgOut.size > 1) shouldBe 0 + cpg.all.collectAll[Block].l.count(b => b._cfgOut.size > 1) shouldBe 0 } } @@ -298,6 +298,64 @@ class MacroHandlingTests extends C2CpgSuite { typeNumCall.columnNumber shouldBe Some(11) } } + + "MacroHandlingTests10" should { + + "have ast parents" in { + val cpg = code(""" + |#define FFSWAP(type,a,b) do{type SWAP_tmp=b; b=a; a=SWAP_tmp;}while(0) + |struct elem_to_channel { + | uint64_t av_position; + | uint8_t syn_ele; + | uint8_t elem_id; + | uint8_t aac_position; + |}; + |int main () { + | struct elem_to_channel e2c_vec[4 * 1] = { { 0 } }; + | int i = 1; + | FFSWAP(struct elem_to_channel, e2c_vec[i - 1], e2c_vec[i]); + |} + |""".stripMargin) + cpg.local.count(l => l._astIn.isEmpty) shouldBe 0 + cpg.local.count(l => l._astIn.size > 1) shouldBe 0 + cpg.local.count(l => l._astIn.size == 1) shouldBe 3 + } + + "only have locals with exactly one ast parent" in { + val cpg = code( + """ + |#define deleteReset(ptr) do { delete ptr; ptr = nullptr; } while(0) + |void func(void) { + | int *foo = new int; + | int *bar = new int; + | int *baz = new int; + | deleteReset(foo); + | deleteReset(bar); + | deleteReset(baz); + |} + |""".stripMargin, + "foo.cc" + ) + val List(foo) = cpg.local.nameExact("foo").l + foo._astIn.size shouldBe 1 + val List(bar) = cpg.local.nameExact("bar").l + bar._astIn.size shouldBe 1 + val List(baz) = cpg.local.nameExact("baz").l + baz._astIn.size shouldBe 1 + } + + "only have local for declaration and not from assignment from broken macro" in { + val cpg = code(""" + |#define FOO() (long)va_arg(ap, int) + |void func(void) { + | int foo; + | foo = FOO(); + | foo = FOO(); + |} + |""".stripMargin) + cpg.local.nameExact("foo").size shouldBe 1 + } + } } class CfgMacroTests extends DataFlowCodeToCpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/MetaDataPassTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/MetaDataPassTests.scala index 2068e17210ff..9996ac6b6995 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/MetaDataPassTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/MetaDataPassTests.scala @@ -2,14 +2,12 @@ package io.joern.c2cpg.passes import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.Languages -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import io.joern.x2cpg.passes.frontend.MetaDataPass import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import scala.jdk.CollectionConverters._ - class MetaDataPassTests extends AnyWordSpec with Matchers { "MetaDataPass" should { @@ -17,11 +15,11 @@ class MetaDataPassTests extends AnyWordSpec with Matchers { new MetaDataPass(cpg, Languages.C, "").createAndApply() "create exactly two nodes" in { - cpg.graph.V.asScala.size shouldBe 2 + cpg.graph.allNodes.size shouldBe 2 } "create no edges" in { - cpg.graph.E.asScala.size shouldBe 0 + cpg.graph.allNodes.outE.size shouldBe 0 } "create a metadata node with correct language" in { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/AstCreationPassTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/AstCreationPassTests.scala index 1fa6e7de118c..5b6d07a90b59 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/AstCreationPassTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/AstCreationPassTests.scala @@ -12,7 +12,6 @@ import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal -import overflowdb.traversal.toNodeTraversal class AstCreationPassTests extends AstC2CpgSuite { @@ -31,6 +30,28 @@ class AstCreationPassTests extends AstC2CpgSuite { } } + "be correct for full names and signatures for method problem bindings" in { + val cpg = code( + """ + |char tpe::foo(char_type a, char b) const { + | return static_cast(a); + |} + |const wchar_t* tpe::foo(const char_type* a, const char_type* b, char c, char* d) const { + | return a; + |} + |""".stripMargin, + "foo.cpp" + ) + // tpe can't be resolved for both methods resulting in problem bindings. + // We can however manually reconstruct the signature from the params and return type without + // relying on the resolved function binding signature. + val List(foo1, foo2) = cpg.method.nameExact("foo").l + foo1.fullName shouldBe "tpe.foo:char(char_type,char)" + foo1.signature shouldBe "char(char_type,char)" + foo2.fullName shouldBe "tpe.foo:const wchar_t*(char_type*,char_type*,char,char*)" + foo2.signature shouldBe "const wchar_t*(char_type*,char_type*,char,char*)" + } + "be correct for packed args" in { val cpg = code( """ @@ -99,246 +120,9 @@ class AstCreationPassTests extends AstC2CpgSuite { } } - "be correct for simple lambda expressions" in { - val cpg = code( - """ - |auto x = [] (int a, int b) -> int - | { return a + b; }; - |auto y = [] (string a, string b) -> string - | { return a + b; }; - |""".stripMargin, - "test.cpp" - ) - val lambda1FullName = "0" - val lambda2FullName = "1" - - cpg.local.name("x").order.l shouldBe List(1) - cpg.local.name("y").order.l shouldBe List(3) - - inside(cpg.assignment.l) { case List(assignment1, assignment2) => - assignment1.order shouldBe 2 - inside(assignment1.astMinusRoot.isMethodRef.l) { case List(ref) => - ref.methodFullName shouldBe lambda1FullName - } - assignment2.order shouldBe 4 - inside(assignment2.astMinusRoot.isMethodRef.l) { case List(ref) => - ref.methodFullName shouldBe lambda2FullName - } - } - - inside(cpg.method.fullNameExact(lambda1FullName).isLambda.l) { case List(l1) => - l1.name shouldBe lambda1FullName - l1.code should startWith("[] (int a, int b) -> int") - l1.signature shouldBe s"int(int,int)" - l1.body.code shouldBe "{ return a + b; }" - } - - inside(cpg.method.fullNameExact(lambda2FullName).isLambda.l) { case List(l2) => - l2.name shouldBe lambda2FullName - l2.code should startWith("[] (string a, string b) -> string") - l2.signature shouldBe s"string(string,string)" - l2.body.code shouldBe "{ return a + b; }" - } - - inside(cpg.typeDecl(NamespaceTraversal.globalNamespaceName).head.bindsOut.l) { - case List(bX: Binding, bY: Binding) => - bX.name shouldBe lambda1FullName - bX.signature shouldBe s"int(int,int)" - inside(bX.refOut.l) { case List(method: Method) => - method.name shouldBe lambda1FullName - method.fullName shouldBe lambda1FullName - method.signature shouldBe s"int(int,int)" - } - bY.name shouldBe lambda2FullName - bY.signature shouldBe s"string(string,string)" - inside(bY.refOut.l) { case List(method: Method) => - method.name shouldBe lambda2FullName - method.fullName shouldBe lambda2FullName - method.signature shouldBe s"string(string,string)" - } - } - } - - "be correct for simple lambda expression in class" in { - val cpg = code( - """ - |class Foo { - | auto x = [] (int a, int b) -> int - | { - | return a + b; - | }; - |}; - | - |""".stripMargin, - "test.cpp" - ) - val lambdaName = "0" - val lambdaFullName = s"Foo.$lambdaName" - val signature = s"int(int,int)" - - cpg.member.name("x").order.l shouldBe List(1) - - inside(cpg.assignment.l) { case List(assignment1) => - inside(assignment1.astMinusRoot.isMethodRef.l) { case List(ref) => - ref.methodFullName shouldBe lambdaFullName - } - } - - inside(cpg.method.fullNameExact(lambdaFullName).isLambda.l) { case List(l1) => - l1.name shouldBe lambdaName - l1.code should startWith("[] (int a, int b) -> int") - l1.signature shouldBe signature - } - - inside(cpg.typeDecl("Foo").head.bindsOut.l) { case List(binding: Binding) => - binding.name shouldBe lambdaName - binding.signature shouldBe signature - inside(binding.refOut.l) { case List(method: Method) => - method.name shouldBe lambdaName - method.fullName shouldBe lambdaFullName - method.signature shouldBe signature - } - } - } - - "be correct for simple lambda expression in class under namespaces" in { - val cpg = code( - """ - |namespace A { class B { - |class Foo { - | auto x = [] (int a, int b) -> int - | { - | return a + b; - | }; - |}; - |};} - |""".stripMargin, - "test.cpp" - ) - val lambdaName = "0" - val lambdaFullName = s"A.B.Foo.$lambdaName" - val signature = s"int(int,int)" - - cpg.member.name("x").order.l shouldBe List(1) - - inside(cpg.assignment.l) { case List(assignment1) => - inside(assignment1.astMinusRoot.isMethodRef.l) { case List(ref) => - ref.methodFullName shouldBe lambdaFullName - } - } - - inside(cpg.method.fullNameExact(lambdaFullName).isLambda.l) { case List(l1) => - l1.name shouldBe lambdaName - l1.code should startWith("[] (int a, int b) -> int") - l1.signature shouldBe signature - } - - inside(cpg.typeDecl.fullNameExact("A.B.Foo").head.bindsOut.l) { case List(binding: Binding) => - binding.name shouldBe lambdaName - binding.signature shouldBe signature - inside(binding.refOut.l) { case List(method: Method) => - method.name shouldBe lambdaName - method.fullName shouldBe lambdaFullName - method.signature shouldBe signature - } - } - } - - "be correct when calling a lambda" in { - val cpg = code( - """ - |auto x = [](int n) -> int - |{ - | return 32 + n; - |}; - | - |constexpr int foo1 = x(10); - |constexpr int foo2 = [](int n) -> int - |{ - | return 32 + n; - |}(10); - |""".stripMargin, - "test.cpp" - ) - val lambda1Name = "0" - val signature1 = s"int(int)" - val lambda2Name = "1" - val signature2 = s"int(int)" - - cpg.local.name("x").order.l shouldBe List(1) - cpg.local.name("foo1").order.l shouldBe List(3) - cpg.local.name("foo2").order.l shouldBe List(5) - - inside(cpg.assignment.l) { case List(assignment1, assignment2, assignment3) => - assignment1.order shouldBe 2 - assignment2.order shouldBe 4 - assignment3.order shouldBe 6 - inside(assignment1.astMinusRoot.isMethodRef.l) { case List(ref) => - ref.methodFullName shouldBe lambda1Name - } - } - - inside(cpg.method.fullNameExact(lambda1Name).isLambda.l) { case List(l1) => - l1.name shouldBe lambda1Name - l1.code should startWith("[](int n) -> int") - l1.signature shouldBe signature1 - } - - inside(cpg.typeDecl(NamespaceTraversal.globalNamespaceName).head.bindsOut.l) { - case List(b1: Binding, b2: Binding) => - b1.name shouldBe lambda1Name - b1.signature shouldBe signature1 - inside(b1.refOut.l) { case List(method: Method) => - method.name shouldBe lambda1Name - method.fullName shouldBe lambda1Name - method.signature shouldBe signature1 - } - b2.name shouldBe lambda2Name - b2.signature shouldBe signature2 - inside(b2.refOut.l) { case List(method: Method) => - method.name shouldBe lambda2Name - method.fullName shouldBe lambda2Name - method.signature shouldBe signature2 - } - } - - inside(cpg.call.nameExact("()").l) { case List(lambda1call, lambda2call) => - lambda1call.name shouldBe "()" - lambda1call.methodFullName shouldBe "():int(int)" - lambda1call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH - inside(lambda1call.astChildren.l) { case List(id: Identifier, lit: Literal) => - id.code shouldBe "x" - lit.code shouldBe "10" - } - inside(lambda1call.argument.l) { case List(lit: Literal) => - lit.code shouldBe "10" - } - inside(lambda1call.receiver.l) { case List(receiver: Identifier) => - receiver.code shouldBe "x" - } - - lambda2call.name shouldBe "()" - lambda2call.methodFullName shouldBe "():int(int)" - lambda2call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH - inside(lambda2call.astChildren.l) { case List(ref: MethodRef, lit: Literal) => - ref.methodFullName shouldBe lambda2Name - ref.code should startWith("[](int n) -> int") - lit.code shouldBe "10" - } - - inside(lambda2call.argument.l) { case List(lit: Literal) => - lit.code shouldBe "10" - } - inside(lambda2call.receiver.l) { case List(ref: MethodRef) => - ref.methodFullName shouldBe lambda2Name - ref.code should startWith("[](int n) -> int") - } - } - } - "be correct for empty method" in { val cpg = code("void method(int x) { }") - inside(cpg.method.name("method").astChildren.l) { + inside(cpg.method.nameExact("method").astChildren.l) { case List(param: MethodParameterIn, _: Block, ret: MethodReturn) => ret.typeFullName shouldBe "void" param.typeFullName shouldBe "int" @@ -354,7 +138,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | free(x); |} |""".stripMargin) - inside(cpg.method.name("method").parameter.l) { case List(param: MethodParameterIn) => + inside(cpg.method.nameExact("method").parameter.l) { case List(param: MethodParameterIn) => param.typeFullName shouldBe "a_struct_type*" param.name shouldBe "a_struct" param.code shouldBe "a_struct_type *a_struct" @@ -369,7 +153,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | free(x); |} |""".stripMargin) - inside(cpg.method.name("method").parameter.l) { case List(param: MethodParameterIn) => + inside(cpg.method.nameExact("method").parameter.l) { case List(param: MethodParameterIn) => param.code shouldBe "struct date *date" param.typeFullName shouldBe "date*" param.name shouldBe "date" @@ -384,7 +168,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | free(x); |} |""".stripMargin) - inside(cpg.method.name("method").parameter.l) { case List(param: MethodParameterIn) => + inside(cpg.method.nameExact("method").parameter.l) { case List(param: MethodParameterIn) => param.typeFullName shouldBe "int[]" param.name shouldBe "x" } @@ -398,7 +182,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | free(x); |} |""".stripMargin) - inside(cpg.method.name("method").parameter.l) { case List(param: MethodParameterIn) => + inside(cpg.method.nameExact("method").parameter.l) { case List(param: MethodParameterIn) => param.typeFullName shouldBe "int[]" param.name shouldBe "" } @@ -412,7 +196,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | free(x); |} |""".stripMargin) - inside(cpg.method.name("method").parameter.l) { case List(param: MethodParameterIn) => + inside(cpg.method.nameExact("method").parameter.l) { case List(param: MethodParameterIn) => param.typeFullName shouldBe "a_struct_type[]" param.name shouldBe "a_struct" } @@ -426,7 +210,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | free(x); |} |""".stripMargin) - inside(cpg.method.name("method").parameter.l) { case List(param: MethodParameterIn) => + inside(cpg.method.nameExact("method").parameter.l) { case List(param: MethodParameterIn) => param.typeFullName shouldBe "a_struct_type[]*" param.name shouldBe "a_struct" } @@ -438,7 +222,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | int local = 1; |} |""".stripMargin) - inside(cpg.method.name("method").block.astChildren.l) { case List(local: Local, call: Call) => + inside(cpg.method.nameExact("method").block.astChildren.l) { case List(local: Local, call: Call) => local.name shouldBe "local" local.typeFullName shouldBe "int" local.order shouldBe 1 @@ -467,7 +251,7 @@ class AstCreationPassTests extends AstC2CpgSuite { |""".stripMargin, "test.cpp" ) - inside(cpg.method.name("method").block.astChildren.l) { case List(_, call1: Call, _, call2: Call) => + inside(cpg.method.nameExact("method").block.astChildren.l) { case List(_, call1: Call, _, call2: Call) => call1.name shouldBe Operators.assignment inside(call2.astChildren.l) { case List(identifier: Identifier, call: Call) => identifier.name shouldBe "is_std_array_v" @@ -492,7 +276,7 @@ class AstCreationPassTests extends AstC2CpgSuite { |void method(int x) { | int local = x; |}""".stripMargin) - cpg.local.name("local").order.l shouldBe List(1) + cpg.local.nameExact("local").order.l shouldBe List(1) inside(cpg.method("method").block.astChildren.assignment.source.l) { case List(identifier: Identifier) => identifier.code shouldBe "x" identifier.typeFullName shouldBe "int" @@ -501,6 +285,20 @@ class AstCreationPassTests extends AstC2CpgSuite { } } + "be correct for decl assignment with references" in { + val cpg = code( + """ + |int addrOfLocalRef(struct x **foo) { + | struct x &bar = **foo; + | *foo = &bar; + |}""".stripMargin, + "foo.cc" + ) + val List(barLocal) = cpg.method.nameExact("addrOfLocalRef").local.l + barLocal.name shouldBe "bar" + barLocal.code shouldBe "struct x &bar" + } + "be correct for decl assignment of multiple locals" in { val cpg = code(""" |void method(int x, int y) { @@ -544,7 +342,7 @@ class AstCreationPassTests extends AstC2CpgSuite { val localZ = cpg.local.order(3) localZ.name.l shouldBe List("z") - inside(cpg.method.name("method").ast.isCall.name(Operators.assignment).cast[OpNodes.Assignment].l) { + inside(cpg.method.nameExact("method").ast.isCall.nameExact(Operators.assignment).cast[OpNodes.Assignment].l) { case List(assignment) => assignment.target.code shouldBe "x" assignment.source.start.isCall.name.l shouldBe List(Operators.addition) @@ -566,7 +364,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | } |} """.stripMargin) - inside(cpg.method.name("method").block.astChildren.l) { case List(local: Local, innerBlock: Block) => + inside(cpg.method.nameExact("method").block.astChildren.l) { case List(local: Local, innerBlock: Block) => local.name shouldBe "x" local.order shouldBe 1 inside(innerBlock.astChildren.l) { case List(localInBlock: Local) => @@ -584,7 +382,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | } |} """.stripMargin) - inside(cpg.method.name("method").block.astChildren.isControlStructure.l) { + inside(cpg.method.nameExact("method").block.astChildren.isControlStructure.l) { case List(controlStruct: ControlStructure) => controlStruct.code shouldBe "while (x < 1)" controlStruct.controlStructureType shouldBe ControlStructureTypes.WHILE @@ -601,13 +399,11 @@ class AstCreationPassTests extends AstC2CpgSuite { val cpg = code(""" |void method(int x) { | int y; - | if (x > 0) { - | y = 0; - | } + | if (x > 0) { y = 0; } |} """.stripMargin) - inside(cpg.method.name("method").controlStructure.l) { case List(controlStruct: ControlStructure) => - controlStruct.code shouldBe "if (x > 0)" + inside(cpg.method.nameExact("method").controlStructure.l) { case List(controlStruct: ControlStructure) => + controlStruct.code shouldBe "if (x > 0) { y = 0; }" controlStruct.controlStructureType shouldBe ControlStructureTypes.IF inside(controlStruct.condition.l) { case List(cndNode) => cndNode.code shouldBe "x > 0" @@ -621,16 +417,12 @@ class AstCreationPassTests extends AstC2CpgSuite { val cpg = code(""" |void method(int x) { | int y; - | if (x > 0) { - | y = 0; - | } else { - | y = 1; - | } + | if (x > 0) { y = 0; } else { y = 1; } |} """.stripMargin) - inside(cpg.method.name("method").controlStructure.l) { case List(ifStmt, elseStmt) => + inside(cpg.method.nameExact("method").controlStructure.l) { case List(ifStmt, elseStmt) => ifStmt.controlStructureType shouldBe ControlStructureTypes.IF - ifStmt.code shouldBe "if (x > 0)" + ifStmt.code shouldBe "if (x > 0) { y = 0; } else { y = 1; }" elseStmt.controlStructureType shouldBe ControlStructureTypes.ELSE elseStmt.code shouldBe "else" @@ -653,7 +445,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | int x = (true ? vlc_dccp_CreateFD : vlc_datagram_CreateFD)(fd); | } """.stripMargin) - inside(cpg.method.name("method").ast.isCall.name(Operators.conditional).l) { case List(call) => + inside(cpg.method.nameExact("method").ast.isCall.nameExact(Operators.conditional).l) { case List(call) => call.code shouldBe "true ? vlc_dccp_CreateFD : vlc_datagram_CreateFD" } } @@ -668,7 +460,7 @@ class AstCreationPassTests extends AstC2CpgSuite { // `cpg.method.call` will not work at this stage // either because there are no CONTAINS edges - inside(cpg.method.name("method").ast.isCall.name(Operators.conditional).l) { case List(call) => + inside(cpg.method.nameExact("method").ast.isCall.nameExact(Operators.conditional).l) { case List(call) => call.code shouldBe "(foo == 1) ? bar : 0" inside(call.argument.l) { case List(condition, trueBranch, falseBranch) => condition.argumentIndex shouldBe 1 @@ -691,17 +483,18 @@ class AstCreationPassTests extends AstC2CpgSuite { |}""".stripMargin, "file.cpp" ) - inside(cpg.method.name("method").controlStructure.l) { case List(forStmt) => + inside(cpg.method.nameExact("method").controlStructure.l) { case List(forStmt) => forStmt.controlStructureType shouldBe ControlStructureTypes.FOR - inside(forStmt.astChildren.order(1).l) { case List(ident: Identifier) => - ident.code shouldBe "list" - } - inside(forStmt.astChildren.order(2).l) { case List(x: Local) => + inside(forStmt.astChildren.isLocal.l) { case List(x: Local) => x.name shouldBe "x" x.typeFullName shouldBe "int" x.code shouldBe "int x" } - inside(forStmt.astChildren.order(3).l) { case List(block: Block) => + // for the expected orders see CfgCreator.cfgForForStatement + inside(forStmt.astChildren.order(2).l) { case List(ident: Identifier) => + ident.code shouldBe "list" + } + inside(forStmt.astChildren.order(5).l) { case List(block: Block) => block.astChildren.isCall.code.l shouldBe List("z = x") } } @@ -717,19 +510,27 @@ class AstCreationPassTests extends AstC2CpgSuite { |""".stripMargin, "test.cpp" ) - inside(cpg.method.name("method").controlStructure.l) { case List(forStmt) => + inside(cpg.method.nameExact("method").controlStructure.l) { case List(forStmt) => forStmt.controlStructureType shouldBe ControlStructureTypes.FOR - inside(forStmt.astChildren.order(1).l) { case List(ident) => - ident.code shouldBe "foo" - } - inside(forStmt.astChildren.order(2).astChildren.l) { case List(idA, idB) => - idA.code shouldBe "a" - idB.code shouldBe "b" - } - inside(forStmt.astChildren.order(3).l) { case List(block) => - block.code shouldBe "" - block.astChildren.l shouldBe empty - } + forStmt.astChildren.isBlock.astChildren.isCall.code.l shouldBe List( + "anonymous_tmp_0 = foo", + "a = anonymous_tmp_0[0]", + "b = anonymous_tmp_0[1]" + ) + } + cpg.local.map { l => (l.name, l.typeFullName) }.toMap shouldBe Map( + "foo" -> "int[2]", + "anonymous_tmp_0" -> "int[2]", + "a" -> "ANY", + "b" -> "ANY" + ) + pendingUntilFixed { + cpg.local.map { l => (l.name, l.typeFullName) }.toMap shouldBe Map( + "foo" -> "int[2]", + "anonymous_tmp_0" -> "int[2]", + "a" -> "int*", + "b" -> "int*" + ) } } @@ -741,7 +542,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | } |} """.stripMargin) - inside(cpg.method.name("method").controlStructure.l) { case List(forStmt) => + inside(cpg.method.nameExact("method").controlStructure.l) { case List(forStmt) => forStmt.controlStructureType shouldBe ControlStructureTypes.FOR childContainsAssignments(forStmt, 1, List("x = 0", "y = 0")) @@ -768,10 +569,10 @@ class AstCreationPassTests extends AstC2CpgSuite { |} """.stripMargin) cpg.method - .name("method") + .nameExact("method") .ast .isCall - .name(Operators.preIncrement) + .nameExact(Operators.preIncrement) .argument(1) .code .l shouldBe List("x") @@ -804,7 +605,7 @@ class AstCreationPassTests extends AstC2CpgSuite { """.stripMargin) val List(forLoop) = cpg.controlStructure.l val List(conditionBlock) = forLoop.condition.collectAll[Block].l - conditionBlock.argumentIndex shouldBe 2 + conditionBlock.order shouldBe 2 val List(assignmentCall, greaterCall) = conditionBlock.astChildren.collectAll[Call].l assignmentCall.argumentIndex shouldBe 1 assignmentCall.code shouldBe "b = something()" @@ -819,10 +620,10 @@ class AstCreationPassTests extends AstC2CpgSuite { |} """.stripMargin) cpg.method - .name("method") + .nameExact("method") .ast .isCall - .name("foo") + .nameExact("foo") .argument(1) .code .l shouldBe List("x") @@ -835,7 +636,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | foo(x); |} """.stripMargin) - inside(cpg.method.name("method").ast.isCall.l) { case List(call: Call) => + inside(cpg.method.nameExact("method").ast.isCall.l) { case List(call: Call) => call.code shouldBe "foo(x)" call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH val rec = call.receiver.l @@ -850,7 +651,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | x.a; |} """.stripMargin) - inside(cpg.method.name("method").ast.isCall.name(Operators.fieldAccess).l) { case List(call) => + inside(cpg.method.nameExact("method").ast.isCall.nameExact(Operators.fieldAccess).l) { case List(call) => val arg1 = call.argument(1) val arg2 = call.argument(2) arg1.isIdentifier shouldBe true @@ -869,7 +670,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | x->a; |} """.stripMargin) - inside(cpg.method.name("method").ast.isCall.name(Operators.indirectFieldAccess).l) { case List(call) => + inside(cpg.method.nameExact("method").ast.isCall.nameExact(Operators.indirectFieldAccess).l) { case List(call) => val arg1 = call.argument(1) val arg2 = call.argument(2) arg1.isIdentifier shouldBe true @@ -888,7 +689,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | return (x->a)(1, 2); |} """.stripMargin) - inside(cpg.method.name("method").ast.isCall.name(Operators.indirectFieldAccess).l) { case List(call) => + inside(cpg.method.nameExact("method").ast.isCall.nameExact(Operators.indirectFieldAccess).l) { case List(call) => val arg1 = call.argument(1) val arg2 = call.argument(2) arg1.isIdentifier shouldBe true @@ -909,9 +710,9 @@ class AstCreationPassTests extends AstC2CpgSuite { | return (*strLenFunc)("123"); |} """.stripMargin) - inside(cpg.method.name("main").ast.isCall.codeExact("(*strLenFunc)(\"123\")").l) { case List(call) => - call.name shouldBe Defines.operatorPointerCall - call.methodFullName shouldBe Defines.operatorPointerCall + inside(cpg.method.nameExact("main").ast.isCall.codeExact("(*strLenFunc)(\"123\")").l) { case List(call) => + call.name shouldBe Defines.OperatorPointerCall + call.methodFullName shouldBe Defines.OperatorPointerCall } } @@ -923,13 +724,13 @@ class AstCreationPassTests extends AstC2CpgSuite { |} """.stripMargin) cpg.method - .name("method") + .nameExact("method") .ast .isCall - .name(Operators.sizeOf) + .nameExact(Operators.sizeOf) .argument(1) .isIdentifier - .name("a") + .nameExact("a") .argumentIndex(1) .size shouldBe 1 } @@ -942,13 +743,13 @@ class AstCreationPassTests extends AstC2CpgSuite { |} """.stripMargin) cpg.method - .name("method") + .nameExact("method") .ast .isCall - .name(Operators.sizeOf) + .nameExact(Operators.sizeOf) .argument(1) .isIdentifier - .name("a") + .nameExact("a") .argumentIndex(1) .size shouldBe 1 } @@ -962,13 +763,13 @@ class AstCreationPassTests extends AstC2CpgSuite { "file.cpp" ) cpg.method - .name("method") + .nameExact("method") .ast .isCall - .name(Operators.sizeOf) + .nameExact(Operators.sizeOf) .argument(1) .isIdentifier - .name("int") + .nameExact("int") .argumentIndex(1) .size shouldBe 1 } @@ -981,7 +782,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | void method() { | }; """.stripMargin) - cpg.method.name("method").size shouldBe 1 + cpg.method.nameExact("method").size shouldBe 1 } "be correct for empty named struct" in { @@ -989,14 +790,14 @@ class AstCreationPassTests extends AstC2CpgSuite { | struct foo { | }; """.stripMargin) - cpg.typeDecl.name("foo").size shouldBe 1 + cpg.typeDecl.nameExact("foo").size shouldBe 1 } "be correct for struct decl" in { val cpg = code(""" | struct foo; """.stripMargin) - cpg.typeDecl.name("foo").size shouldBe 1 + cpg.typeDecl.nameExact("foo").size shouldBe 1 } "be correct for named struct with single field" in { @@ -1006,10 +807,10 @@ class AstCreationPassTests extends AstC2CpgSuite { | }; """.stripMargin) cpg.typeDecl - .name("foo") + .nameExact("foo") .member .code("x") - .name("x") + .nameExact("x") .typeFullName("int") .size shouldBe 1 } @@ -1022,7 +823,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | int z; | }; """.stripMargin) - cpg.typeDecl.name("foo").member.code.toSetMutable shouldBe Set("x", "y", "z") + cpg.typeDecl.nameExact("foo").member.code.toSetMutable shouldBe Set("x", "y", "z") } "be correct for named struct with nested struct" in { @@ -1037,12 +838,12 @@ class AstCreationPassTests extends AstC2CpgSuite { | }; | }; """.stripMargin) - inside(cpg.typeDecl.name("foo").l) { case List(fooStruct: TypeDecl) => - fooStruct.member.name("x").size shouldBe 1 + inside(cpg.typeDecl.nameExact("foo").l) { case List(fooStruct: TypeDecl) => + fooStruct.member.nameExact("x").size shouldBe 1 inside(fooStruct.astChildren.isTypeDecl.l) { case List(barStruct: TypeDecl) => - barStruct.member.name("y").size shouldBe 1 + barStruct.member.nameExact("y").size shouldBe 1 inside(barStruct.astChildren.isTypeDecl.l) { case List(foo2Struct: TypeDecl) => - foo2Struct.member.name("z").size shouldBe 1 + foo2Struct.member.nameExact("z").size shouldBe 1 } } } @@ -1053,7 +854,14 @@ class AstCreationPassTests extends AstC2CpgSuite { |typedef struct foo { |} abc; """.stripMargin) - cpg.typeDecl.name("foo").aliasTypeFullName("abc").size shouldBe 1 + cpg.typeDecl.nameExact("foo").aliasTypeFullName("abc").size shouldBe 1 + } + + "be correct for anonymous typedef struct" in { + val cpg = code("typedef struct { int m; } t;", "t.cpp") + val List(t) = cpg.typeDecl.nameExact("t").l + cpg.typeDecl.nameExact("ANY").size shouldBe 0 + t.aliasTypeFullName.size shouldBe 0 // no alias for named anonymous typedefs } "be correct for struct with local" in { @@ -1067,11 +875,11 @@ class AstCreationPassTests extends AstC2CpgSuite { x.name shouldBe "x" x.typeFullName shouldBe "int" } - cpg.typeDecl.name("B").size shouldBe 1 + cpg.typeDecl.nameExact("B").size shouldBe 1 inside(cpg.local.l) { case List(localA, localB) => localA.name shouldBe "a" localA.typeFullName shouldBe "A" - localA.code shouldBe "struct A a" + localA.code shouldBe "struct A { int x; } a" localB.name shouldBe "b" localB.typeFullName shouldBe "B" localB.code shouldBe "struct B b" @@ -1104,10 +912,10 @@ class AstCreationPassTests extends AstC2CpgSuite { | i = 0; |} """.stripMargin) - val List(localMyOtherFs) = cpg.method("main").local.name("my_other_fs").l + val List(localMyOtherFs) = cpg.method("main").local.nameExact("my_other_fs").l localMyOtherFs.order shouldBe 2 localMyOtherFs.referencingIdentifiers.name.l shouldBe List("my_other_fs") - val List(localMyFs) = cpg.local.name("my_fs").l + val List(localMyFs) = cpg.local.nameExact("my_fs").l localMyFs.order shouldBe 4 localMyFs.referencingIdentifiers.name.l shouldBe List("my_fs") cpg.typeDecl.nameNot(NamespaceTraversal.globalNamespaceName).fullName.l.distinct shouldBe List("filesystem") @@ -1118,7 +926,7 @@ class AstCreationPassTests extends AstC2CpgSuite { |typedef enum foo { |} abc; """.stripMargin) - cpg.typeDecl.name("foo").aliasTypeFullName("abc").size shouldBe 1 + cpg.typeDecl.nameExact("foo").aliasTypeFullName("abc").size shouldBe 1 } "be correct for classes with friends" in { @@ -1150,7 +958,7 @@ class AstCreationPassTests extends AstC2CpgSuite { "file.cpp" ) cpg.typeDecl - .name("Derived") + .nameExact("Derived") .count(_.inheritsFromTypeFullName == List("Base")) shouldBe 1 } @@ -1161,7 +969,7 @@ class AstCreationPassTests extends AstC2CpgSuite { """.stripMargin, "file.cpp" ) - inside(cpg.call.name(Operators.cast).l) { case List(call: Call) => + inside(cpg.call.nameExact(Operators.cast).l) { case List(call: Call) => call.argument(2).code shouldBe "{ 1 }" call.argument(1).code shouldBe "int" } @@ -1300,7 +1108,7 @@ class AstCreationPassTests extends AstC2CpgSuite { "file.cpp" ) cpg.typeDecl - .name("Y") + .nameExact("Y") .l .size shouldBe 1 } @@ -1319,7 +1127,7 @@ class AstCreationPassTests extends AstC2CpgSuite { "file.cpp" ) cpg.method - .name("f") + .nameExact("f") .l .size shouldBe 1 } @@ -1352,10 +1160,10 @@ class AstCreationPassTests extends AstC2CpgSuite { |} |""".stripMargin) cpg.method - .name("foo") + .nameExact("foo") .ast .isCall - .name("bar") + .nameExact("bar") .argument .code("x") .size shouldBe 1 @@ -1368,16 +1176,14 @@ class AstCreationPassTests extends AstC2CpgSuite { |} |""".stripMargin) // TODO no step class defined for `Return` nodes - cpg.method.name("d").ast.isReturn.astChildren.order(1).isCall.code.l shouldBe List("x * 2") + cpg.method.nameExact("d").ast.isReturn.astChildren.order(1).isCall.code.l shouldBe List("x * 2") cpg.method - .name("d") + .nameExact("d") .ast .isReturn - .outE(EdgeTypes.ARGUMENT) + .out(EdgeTypes.ARGUMENT) .head - .inNode() - .get - .asInstanceOf[CallDb] + .asInstanceOf[Call] .code shouldBe "x * 2" } @@ -1387,7 +1193,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | return x * 2; |} |""".stripMargin) - cpg.call.name(Operators.multiplication).code.l shouldBe List("x * 2") + cpg.call.nameExact(Operators.multiplication).code.l shouldBe List("x * 2") } "be correct for unary method calls" in { @@ -1396,7 +1202,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | return !b; |} |""".stripMargin) - cpg.call.name(Operators.logicalNot).argument(1).code.l shouldBe List("b") + cpg.call.nameExact(Operators.logicalNot).argument(1).code.l shouldBe List("b") } "be correct for unary expr" in { @@ -1407,7 +1213,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | return end ? (int)(end - str) : max; | } |""".stripMargin) - inside(cpg.call.name(Operators.cast).astChildren.l) { case List(tpe: Unknown, call: Call) => + inside(cpg.call.nameExact(Operators.cast).astChildren.l) { case List(tpe: Unknown, call: Call) => call.code shouldBe "end - str" call.argumentIndex shouldBe 2 tpe.code shouldBe "int" @@ -1423,8 +1229,8 @@ class AstCreationPassTests extends AstC2CpgSuite { | return pos; |} |""".stripMargin) - cpg.call.name(Operators.postIncrement).argument(1).code("x").size shouldBe 1 - cpg.call.name(Operators.postDecrement).argument(1).code("x").size shouldBe 1 + cpg.call.nameExact(Operators.postIncrement).argument(1).code("x").size shouldBe 1 + cpg.call.nameExact(Operators.postDecrement).argument(1).code("x").size shouldBe 1 } "be correct for conditional expressions containing calls" in { @@ -1433,7 +1239,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | return x > 0 ? x : -x; |} |""".stripMargin) - cpg.call.name(Operators.conditional).argument.code.l shouldBe List("x > 0", "x", "-x") + cpg.call.nameExact(Operators.conditional).argument.code.l shouldBe List("x > 0", "x", "-x") } "be correct for sizeof expressions" in { @@ -1442,7 +1248,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | return sizeof(int); |} |""".stripMargin) - inside(cpg.call.name(Operators.sizeOf).argument(1).l) { case List(i: Identifier) => + inside(cpg.call.nameExact(Operators.sizeOf).argument(1).l) { case List(i: Identifier) => i.code shouldBe "int" i.name shouldBe "int" } @@ -1459,7 +1265,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | return x[0]; |} |""".stripMargin) - cpg.call.name(Operators.indirectIndexAccess).argument.code.l shouldBe List("x", "0") + cpg.call.nameExact(Operators.indirectIndexAccess).argument.code.l shouldBe List("x", "0") } "be correct for type casts" in { @@ -1468,7 +1274,7 @@ class AstCreationPassTests extends AstC2CpgSuite { | return (int) x; |} |""".stripMargin) - cpg.call.name(Operators.cast).argument.code.l shouldBe List("int", "x") + cpg.call.nameExact(Operators.cast).argument.code.l shouldBe List("int", "x") } "be correct for 'new' array" in { @@ -1482,7 +1288,7 @@ class AstCreationPassTests extends AstC2CpgSuite { "file.cpp" ) // TODO: ".new" is not part of Operators - cpg.call.name(".new").code("new int\\[n\\]").argument.code("int").size shouldBe 1 + cpg.call.nameExact(".new").code("new int\\[n\\]").argument.code("int").size shouldBe 1 } "be correct for 'new' with explicit identifier" in { @@ -1496,7 +1302,7 @@ class AstCreationPassTests extends AstC2CpgSuite { "file.cpp" ) // TODO: ".new" is not part of Operators - val List(newCall) = cpg.call.name(".new").l + val List(newCall) = cpg.call.nameExact(".new").l val List(string, hi, buf) = newCall.argument.l string.argumentIndex shouldBe 1 string.code shouldBe "string" @@ -1510,17 +1316,61 @@ class AstCreationPassTests extends AstC2CpgSuite { "be correct for array size" in { val cpg = code(""" |int main() { - | char buf[256]; - | printf("%s", buf); + | char bufA[256]; + | char bufB[1+2]; |} |""".stripMargin) - inside(cpg.local.l) { case List(buf: Local) => - buf.typeFullName shouldBe "char[256]" - buf.name shouldBe "buf" - buf.code shouldBe "char[256] buf" + inside(cpg.call.nameExact(Operators.assignment).l) { case List(bufCallAAssign: Call, bufCallBAssign: Call) => + val List(bufAId, bufCallA) = bufCallAAssign.argument.l + bufAId.code shouldBe "bufA" + val List(bufBId, bufCallB) = bufCallBAssign.argument.l + bufBId.code shouldBe "bufB" + + inside(cpg.call.nameExact(Operators.alloc).l) { case List(bufCallAAlloc: Call, bufCallBAlloc: Call) => + bufCallAAlloc shouldBe bufCallA + bufCallBAlloc shouldBe bufCallB + + bufCallAAlloc.code shouldBe "bufA[256]" + bufCallAAlloc.typeFullName shouldBe "char[256]" + val List(argA) = bufCallAAlloc.argument.isLiteral.l + argA.code shouldBe "256" + + bufCallBAlloc.code shouldBe "bufB[1+2]" + bufCallBAlloc.typeFullName shouldBe "char[1+2]" + val List(argB) = bufCallBAlloc.argument.isCall.l + argB.name shouldBe Operators.addition + argB.code shouldBe "1+2" + val List(one, two) = argB.argument.isLiteral.l + one.code shouldBe "1" + two.code shouldBe "2" + } + } + + inside(cpg.local.l) { case List(bufA: Local, bufB: Local) => + bufA.typeFullName shouldBe "char[256]" + bufA.name shouldBe "bufA" + bufA.code shouldBe "char bufA[256]" + + bufB.typeFullName shouldBe "char[1+2]" + bufB.name shouldBe "bufB" + bufB.code shouldBe "char bufB[1+2]" } } + "be correct for empty array init" in { + val cpg = code(""" + |void other(void) { + | int i = 0; + | char str[] = "abc"; + | printf("%d %s", i, str); + |} + |""".stripMargin) + val List(str1, str2) = cpg.identifier.nameExact("str").l + str1.typeFullName shouldBe "char[]" + str2.typeFullName shouldBe "char[]" + cpg.call.nameExact(Operators.alloc) shouldBe empty + } + "be correct for array init" in { val cpg = code(""" |int x[] = {0, 1, 2, 3}; @@ -1658,6 +1508,27 @@ class AstCreationPassTests extends AstC2CpgSuite { } } + "be correct for method refs from function pointers" in { + val cpg = code(""" + |uid_t getuid(void); + |void someFunction() {} + |void checkFunctionPointerComparison() { + | if (getuid == 0 || someFunction == 0) {} + |} + |""".stripMargin) + val List(methodA) = cpg.method.fullNameExact("getuid").l + val List(methodB) = cpg.method.fullNameExact("someFunction").l + cpg.method.fullNameExact("checkFunctionPointerComparison").size shouldBe 1 + inside(cpg.call.nameExact(Operators.equals).l) { case List(callA: Call, callB: Call) => + val getuidRef = callA.argument(1).asInstanceOf[MethodRef] + getuidRef.methodFullName shouldBe methodA.fullName + getuidRef.typeFullName shouldBe methodA.methodReturn.typeFullName + val someFunctionRef = callB.argument(1).asInstanceOf[MethodRef] + someFunctionRef.methodFullName shouldBe methodB.fullName + someFunctionRef.typeFullName shouldBe methodB.methodReturn.typeFullName + } + } + "be correct for locals for array init" in { val cpg = code(""" |bool x[2] = { TRUE, FALSE }; @@ -1738,7 +1609,7 @@ class AstCreationPassTests extends AstC2CpgSuite { |""".stripMargin, "file.cpp" ) - cpg.call.name(".new").codeExact("new Foo(n, 42)").argument.code("Foo").size shouldBe 1 + cpg.call.nameExact(".new").codeExact("new Foo(n, 42)").argument.code("Foo").size shouldBe 1 } "be correct for simple 'delete'" in { @@ -1750,7 +1621,7 @@ class AstCreationPassTests extends AstC2CpgSuite { |""".stripMargin, "file.cpp" ) - cpg.call.name(Operators.delete).code("delete n").argument.code("n").size shouldBe 1 + cpg.call.nameExact(Operators.delete).code("delete n").argument.code("n").size shouldBe 1 } "be correct for array 'delete'" in { @@ -1762,7 +1633,7 @@ class AstCreationPassTests extends AstC2CpgSuite { |""".stripMargin, "file.cpp" ) - cpg.call.name(Operators.delete).codeExact("delete[] n").argument.code("n").size shouldBe 1 + cpg.call.nameExact(Operators.delete).codeExact("delete[] n").argument.code("n").size shouldBe 1 } "be correct for const_cast" in { @@ -1775,7 +1646,7 @@ class AstCreationPassTests extends AstC2CpgSuite { |""".stripMargin, "file.cpp" ) - cpg.call.name(Operators.cast).codeExact("const_cast(n)").argument.code.l shouldBe List("int", "n") + cpg.call.nameExact(Operators.cast).codeExact("const_cast(n)").argument.code.l shouldBe List("int", "n") } "be correct for static_cast" in { @@ -1788,7 +1659,7 @@ class AstCreationPassTests extends AstC2CpgSuite { |""".stripMargin, "file.cpp" ) - cpg.call.name(Operators.cast).codeExact("static_cast(n)").argument.code.l shouldBe List("int", "n") + cpg.call.nameExact(Operators.cast).codeExact("static_cast(n)").argument.code.l shouldBe List("int", "n") } "be correct for dynamic_cast" in { @@ -1801,7 +1672,7 @@ class AstCreationPassTests extends AstC2CpgSuite { |""".stripMargin, "file.cpp" ) - cpg.call.name(Operators.cast).codeExact("dynamic_cast(n)").argument.code.l shouldBe List("int", "n") + cpg.call.nameExact(Operators.cast).codeExact("dynamic_cast(n)").argument.code.l shouldBe List("int", "n") } "be correct for reinterpret_cast" in { @@ -1814,7 +1685,7 @@ class AstCreationPassTests extends AstC2CpgSuite { |""".stripMargin, "file.cpp" ) - cpg.call.name(Operators.cast).codeExact("reinterpret_cast(n)").argument.code.l shouldBe List("int", "n") + cpg.call.nameExact(Operators.cast).codeExact("reinterpret_cast(n)").argument.code.l shouldBe List("int", "n") } "be correct for designated initializers in plain C" in { @@ -1823,14 +1694,14 @@ class AstCreationPassTests extends AstC2CpgSuite { | int a[3] = { [1] = 5, [2] = 10, [3 ... 9] = 15 }; |}; """.stripMargin) - inside(cpg.assignment.head.astChildren.l) { case List(ident: Identifier, call: Call) => + inside(cpg.assignment.l(1).astChildren.l) { case List(ident: Identifier, call: Call) => ident.typeFullName shouldBe "int[3]" ident.order shouldBe 1 call.code shouldBe "{ [1] = 5, [2] = 10, [3 ... 9] = 15 }" call.order shouldBe 2 call.name shouldBe Operators.arrayInitializer call.methodFullName shouldBe Operators.arrayInitializer - val children = call.astMinusRoot.isCall.name(Operators.assignment).l + val children = call.astMinusRoot.isCall.nameExact(Operators.assignment).l val args = call.argument.astChildren.l inside(children) { case List(call1, call2, call3) => call1.code shouldBe "[1] = 5" @@ -1863,14 +1734,14 @@ class AstCreationPassTests extends AstC2CpgSuite { """.stripMargin, "test.cpp" ) - inside(cpg.assignment.head.astChildren.l) { case List(ident: Identifier, call: Call) => + inside(cpg.assignment.l(1).astChildren.l) { case List(ident: Identifier, call: Call) => ident.typeFullName shouldBe "int[3]" ident.order shouldBe 1 call.code shouldBe "{ [1] = 5, [2] = 10, [3 ... 9] = 15 }" call.order shouldBe 2 call.name shouldBe Operators.arrayInitializer call.methodFullName shouldBe Operators.arrayInitializer - val children = call.astMinusRoot.isCall.name(Operators.assignment).l + val children = call.astMinusRoot.isCall.nameExact(Operators.assignment).l val args = call.argument.astChildren.l inside(children) { case List(call1, call2, call3) => call1.code shouldBe "[1] = 5" @@ -1927,10 +1798,10 @@ class AstCreationPassTests extends AstC2CpgSuite { val cpg = code( """ |class Point3D { - | public: - | int x; - | int y; - | int z; + | public: + | int x; + | int y; + | int z; |}; | |void foo() { @@ -1939,32 +1810,7 @@ class AstCreationPassTests extends AstC2CpgSuite { """.stripMargin, "test.cpp" ) - inside(cpg.call.code("point3D \\{ .x = 1, .y = 2, .z = 3 \\}").l) { case List(call: Call) => - call.name shouldBe "point3D" - call.methodFullName shouldBe "point3D" - inside(call.astChildren.l) { case List(initCall: Call) => - initCall.code shouldBe "{ .x = 1, .y = 2, .z = 3 }" - initCall.name shouldBe Operators.arrayInitializer - initCall.methodFullName shouldBe Operators.arrayInitializer - val children = initCall.astMinusRoot.isCall.l - val args = initCall.argument.astChildren.l - inside(children) { case List(call1, call2, call3) => - call1.code shouldBe ".x = 1" - call1.name shouldBe Operators.assignment - call1.astMinusRoot.code.l shouldBe List("x", "1") - call1.argument.code.l shouldBe List("x", "1") - call2.code shouldBe ".y = 2" - call2.name shouldBe Operators.assignment - call2.astMinusRoot.code.l shouldBe List("y", "2") - call2.argument.code.l shouldBe List("y", "2") - call3.code shouldBe ".z = 3" - call3.name shouldBe Operators.assignment - call3.astMinusRoot.code.l shouldBe List("z", "3") - call3.argument.code.l shouldBe List("z", "3") - } - children shouldBe args - } - } + cpg.assignment.code.sorted.l shouldBe List("point3D.x = 1", "point3D.y = 2", "point3D.z = 3") } "be correct for call with pack expansion" in { @@ -2047,8 +1893,8 @@ class AstCreationPassTests extends AstC2CpgSuite { | x = 1; | } """.stripMargin) - cpg.method.name("method").lineNumber.l shouldBe List(6) - cpg.method.name("method").block.assignment.lineNumber.l shouldBe List(8) + cpg.method.nameExact("method").lineNumber.l shouldBe List(6) + cpg.method.nameExact("method").block.assignment.lineNumber.l shouldBe List(8) } // for https://github.com/ShiftLeftSecurity/codepropertygraph/issues/1321 @@ -2123,9 +1969,9 @@ class AstCreationPassTests extends AstC2CpgSuite { val cpg = code("class Foo { char (*(*x())[5])() }", "test.cpp") val List(method) = cpg.method.nameNot("").l method.name shouldBe "x" - method.fullName shouldBe "Foo.x:char (* (*)[5])()()" + method.fullName shouldBe "Foo.x:char(*(*)[5])()()" method.code shouldBe "char (*(*x())[5])()" - method.signature shouldBe "char()" + method.signature shouldBe "char(*(*)[5])()()" } "be consistent with pointer types" in { @@ -2135,10 +1981,10 @@ class AstCreationPassTests extends AstC2CpgSuite { | char *x; |} |""".stripMargin) - cpg.member.name("z").typeFullName.head shouldBe "char*" - cpg.parameter.name("y").typeFullName.head shouldBe "char*" - cpg.local.name("x").typeFullName.head shouldBe "char*" - cpg.method.name("a").methodReturn.typeFullName.head shouldBe "char*" + cpg.member.nameExact("z").typeFullName.head shouldBe "char*" + cpg.parameter.nameExact("y").typeFullName.head shouldBe "char*" + cpg.local.nameExact("x").typeFullName.head shouldBe "char*" + cpg.method.nameExact("a").methodReturn.typeFullName.head shouldBe "char*" } "be consistent with array types" in { @@ -2148,9 +1994,9 @@ class AstCreationPassTests extends AstC2CpgSuite { | char x[1]; |} |""".stripMargin) - cpg.member.name("z").typeFullName.head shouldBe "char[1]" - cpg.parameter.name("y").typeFullName.head shouldBe "char[1]" - cpg.local.name("x").typeFullName.head shouldBe "char[1]" + cpg.member.nameExact("z").typeFullName.head shouldBe "char[1]" + cpg.parameter.nameExact("y").typeFullName.head shouldBe "char[1]" + cpg.local.nameExact("x").typeFullName.head shouldBe "char[1]" } "be consistent with long number types" in { @@ -2163,8 +2009,10 @@ class AstCreationPassTests extends AstC2CpgSuite { |""".stripMargin) val List(bufLocal) = cpg.local.nameExact("buf").l bufLocal.typeFullName shouldBe "char[0x111111111111111]" - bufLocal.code shouldBe "char[0x111111111111111] buf" - cpg.literal.code.l shouldBe List("0x111111111111111") + bufLocal.code shouldBe "char buf[BUFSIZE]" + val List(bufAllocCall) = cpg.call.nameExact(Operators.alloc).l + bufAllocCall.code shouldBe "buf[BUFSIZE]" + bufAllocCall.argument.ast.isLiteral.code.l shouldBe List("0x111111111111111") } } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/CallConventionsTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/CallConventionsTests.scala index 1adaddce886d..41f1ee51627c 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/CallConventionsTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/CallConventionsTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.passes.ast import io.joern.c2cpg.testfixtures.AstC2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal class CallConventionsTests extends AstC2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/CallTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/CallTests.scala index a7efec6f1df8..b928b1b3d38d 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/CallTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/CallTests.scala @@ -9,8 +9,6 @@ import io.shiftleft.codepropertygraph.generated.nodes.Literal import io.shiftleft.semanticcpg.language.NoResolve import io.shiftleft.semanticcpg.language.* -import java.nio.file.{Files, Path} - class CallTests extends C2CpgSuite { implicit val resolver: NoResolve.type = NoResolve @@ -112,7 +110,7 @@ class CallTests extends C2CpgSuite { "have the correct callIn" in { val List(m) = cpg.method.nameNot("").where(_.ast.isReturn.code(".*nullptr.*")).l val List(c) = cpg.call.codeExact("b->GetObj()").l - c.callee.head shouldBe m + c.callee.l should contain(m) val List(callIn) = m.callIn.l callIn.code shouldBe "b->GetObj()" } @@ -341,9 +339,9 @@ class CallTests extends C2CpgSuite { "test.cpp" ) - val List(call) = cpg.call.nameExact(Defines.operatorPointerCall).l + val List(call) = cpg.call.nameExact(Defines.OperatorPointerCall).l call.signature shouldBe "" - call.methodFullName shouldBe Defines.operatorPointerCall + call.methodFullName shouldBe Defines.OperatorPointerCall call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH call.typeFullName shouldBe "void" @@ -408,9 +406,9 @@ class CallTests extends C2CpgSuite { "test.c" ) - val List(call) = cpg.call.nameExact(Defines.operatorPointerCall).l + val List(call) = cpg.call.nameExact(Defines.OperatorPointerCall).l call.signature shouldBe "" - call.methodFullName shouldBe Defines.operatorPointerCall + call.methodFullName shouldBe Defines.OperatorPointerCall call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH call.typeFullName shouldBe "void" @@ -562,9 +560,9 @@ class CallTests extends C2CpgSuite { "test.c" ) - val List(call) = cpg.call.nameExact(Defines.operatorPointerCall).l + val List(call) = cpg.call.nameExact(Defines.OperatorPointerCall).l call.signature shouldBe "" - call.methodFullName shouldBe Defines.operatorPointerCall + call.methodFullName shouldBe Defines.OperatorPointerCall call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH call.typeFullName shouldBe X2CpgDefines.Any @@ -610,9 +608,9 @@ class CallTests extends C2CpgSuite { "test.c" ) - val List(call) = cpg.call.nameExact(Defines.operatorPointerCall).l + val List(call) = cpg.call.nameExact(Defines.OperatorPointerCall).l call.signature shouldBe "" - call.methodFullName shouldBe Defines.operatorPointerCall + call.methodFullName shouldBe Defines.OperatorPointerCall call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH call.typeFullName shouldBe X2CpgDefines.Any diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/ControlStructureTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/ControlStructureTests.scala index b918bb66bca4..a46253eefc33 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/ControlStructureTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/ControlStructureTests.scala @@ -3,9 +3,10 @@ package io.joern.c2cpg.passes.ast import io.joern.c2cpg.parser.FileDefaults import io.joern.c2cpg.testfixtures.C2CpgSuite import io.shiftleft.codepropertygraph.generated.ControlStructureTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* +import org.apache.commons.lang3.StringUtils -class ControlStructureTests extends C2CpgSuite(FileDefaults.CPP_EXT) { +class ControlStructureTests extends C2CpgSuite(FileDefaults.CppExt) { "ControlStructureTest1" should { val cpg = code(""" @@ -50,7 +51,9 @@ class ControlStructureTests extends C2CpgSuite(FileDefaults.CPP_EXT) { } "should identify `switch` block" in { - cpg.method("foo").switchBlock.code.l shouldBe List("switch(y)") + cpg.method("foo").switchBlock.code.map(StringUtils.normalizeSpace).l shouldBe List( + "switch(y) { case 1: printf(\"bar\\n\"); break; default: }" + ) } "should identify `for` block" in { @@ -89,26 +92,26 @@ class ControlStructureTests extends C2CpgSuite(FileDefaults.CPP_EXT) { "should be correct for for-loop with multiple assignments" in { inside(cpg.controlStructure.l) { case List(forLoop) => forLoop.controlStructureType shouldBe ControlStructureTypes.FOR - inside(forLoop.astChildren.order(1).l) { case List(assignmentBlock) => - inside(assignmentBlock.astChildren.l) { case List(localX, localY, assignmentX, assignmentY) => - localX.code shouldBe "int x" - localX.order shouldBe 1 - localY.code shouldBe "int y" - localY.order shouldBe 2 + inside(forLoop.astChildren.isLocal.l) { case List(localX, localY) => + localX.code shouldBe "int x" + localY.code shouldBe "int y" + } + inside(forLoop.astChildren.order(3).l) { case List(assignmentBlock) => + inside(assignmentBlock.astChildren.l) { case List(assignmentX, assignmentY) => assignmentX.code shouldBe "x=1" - assignmentX.order shouldBe 3 + assignmentX.order shouldBe 1 assignmentY.code shouldBe "y=1" - assignmentY.order shouldBe 4 + assignmentY.order shouldBe 2 } } inside(forLoop.condition.l) { case List(x) => x.code shouldBe "x" - x.order shouldBe 2 + x.order shouldBe 4 } - inside(forLoop.astChildren.order(3).l) { case List(updateX) => + inside(forLoop.astChildren.order(5).l) { case List(updateX) => updateX.code shouldBe "--x" } - inside(forLoop.astChildren.order(4).l) { case List(loopBody) => + inside(forLoop.astChildren.order(6).l) { case List(loopBody) => loopBody.astChildren.isCall.head.code shouldBe "bar()" } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/DependencyTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/DependencyTests.scala index 5961a22bd2f1..ff9882703a7d 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/DependencyTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/DependencyTests.scala @@ -2,7 +2,7 @@ package io.joern.c2cpg.passes.ast import io.joern.c2cpg.testfixtures.AstC2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DependencyTests extends AstC2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/FileTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/FileTests.scala index 83da69ec968d..244250740568 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/FileTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/FileTests.scala @@ -65,6 +65,51 @@ class FileTests extends C2CpgSuite { } } + "File test for single file with a header include" should { + + val cpg = code( + """ + |#include "fetch.h" + |#include "cache.h" + |const char *write_ref = NULL; + |void pull_say(const char *fmt, const char *hex { + | if (get_verbosely) { fprintf(stderr, fmt, hex); } + |} + |""".stripMargin, + "fetch.c" + ) + + "contain the correct file nodes" in { + cpg.file.name.sorted.l shouldBe List("", "", "fetch.c") + } + + } + + "File test for single file with a header include that actually exists" should { + + val cpg = code( + """ + |#include "fetch.h" + |#include "cache.h" + |const char *write_ref = NULL; + |void pull_say(const char *fmt, const char *hex { + | if (get_verbosely) { fprintf(stderr, fmt, hex); } + |} + |""".stripMargin, + "fetch.c" + ).moreCode( + """ + |extern const char *write_ref; + |""".stripMargin, + "fetch.h" + ) + + "contain the correct file nodes" in { + cpg.file.name.sorted.l shouldBe List("", "", "fetch.c", "fetch.h") + } + + } + "File test for multiple source files and preprocessed files" should { val cpg = code("int foo() {}", "main.c") diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/HeaderAstCreationPassTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/HeaderAstCreationPassTests.scala index 53a624e68242..e09854f53372 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/HeaderAstCreationPassTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/HeaderAstCreationPassTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.passes.ast import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal class HeaderAstCreationPassTests extends C2CpgSuite { @@ -37,20 +37,17 @@ class HeaderAstCreationPassTests extends C2CpgSuite { "de-duplicate content correctly" in { inside(cpg.method.nameNot(NamespaceTraversal.globalNamespaceName).sortBy(_.fullName)) { - case Seq(bar, foo, m1, m2, printf) => + case Seq(bar, foo, m, printf) => // note that we don't see bar twice even so it is contained // in main.h and included in main.c and we do scan both bar.fullName shouldBe "bar" bar.filename shouldBe "main.h" foo.fullName shouldBe "foo" foo.filename shouldBe "other.h" - // main is include twice. First time for the header file, - // second time for the actual implementation in the source file - // We do not de-duplicate this as line/column numbers differ - m1.fullName shouldBe "main" - m1.filename shouldBe "main.c" - m2.fullName shouldBe "main" - m2.filename shouldBe "main.h" + // main is also deduplicated. It is defined within the header file, + // and has an actual implementation in the source file + m.fullName shouldBe "main" + m.filename shouldBe "main.c" printf.fullName shouldBe "printf" } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/LambdaExpressionTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/LambdaExpressionTests.scala new file mode 100644 index 000000000000..59a3409ea410 --- /dev/null +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/LambdaExpressionTests.scala @@ -0,0 +1,655 @@ +package io.joern.c2cpg.passes.ast + +import io.joern.c2cpg.astcreation.Defines +import io.joern.c2cpg.parser.FileDefaults +import io.joern.c2cpg.testfixtures.AstC2CpgSuite +import io.shiftleft.codepropertygraph.generated.DispatchTypes +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.EvaluationStrategies +import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal + +class LambdaExpressionTests extends AstC2CpgSuite(FileDefaults.CppExt) { + + "a simple lambda expression as argument" should { + val cpg = code(""" + |class Foo { + | public: + | string getFromSupplier(string input, std::function& mapper) { + | return mapper.apply(input); + | } + | + | void foo(string input, string fallback) { + | getFromSupplier( + | input, + | [fallback] (string lambdaInput) -> string { return lambdaInput.length() > 5 ? "Long" : fallback; } + | ); + | } + |} + |""".stripMargin) + + "create the correct typedecl node for the lambda" in { + cpg.typeDecl.name(".*lambda.*").name.l shouldBe List("0") + cpg.typeDecl.name(".*lambda.*").fullName.l shouldBe List("Test0.cpp:.Foo.foo.0:string(string)") + } + + "ref the lambda param correctly" in { + val List(lambdaMethod) = cpg.typeDecl.name("Foo").method.name(".*lambda.*").isLambda.l + val List(param) = lambdaMethod.parameter.l + val List(reffedParam) = cpg.identifier.nameExact("lambdaInput").refsTo.collectAll[MethodParameterIn].l + reffedParam shouldBe param + } + + "create a method node for the lambda" in { + cpg.typeDecl.name("Foo").method.name(".*lambda.*").isLambda.l match { + case List(lambdaMethod) => + lambdaMethod.name shouldBe "0" + lambdaMethod.fullName shouldBe "Test0.cpp:.Foo.foo.0:string(string)" + lambdaMethod.parameter.l match { + case List(lambdaInput) => + lambdaInput.name shouldBe "lambdaInput" + lambdaInput.typeFullName shouldBe "string" + case result => fail(s"Expected single lambda parameter but got $result") + } + lambdaMethod.methodReturn.typeFullName shouldBe "string" + case result => fail(s"Expected single lambda method but got $result") + } + } + + "create a binding for the lambda method" in { + val List(typeDecl) = cpg.typeDecl.fullNameExact("Test0.cpp:.Foo.foo.0:string(string)").l + val List(binding) = typeDecl.bindsOut.l + binding.name shouldBe Defines.OperatorCall + binding.methodFullName shouldBe "Test0.cpp:.Foo.foo.0:string(string)" + binding.signature shouldBe "string(string)" + val List(methodReffed) = binding.refOut.l + methodReffed.name shouldBe "0" + methodReffed.fullName shouldBe "Test0.cpp:.Foo.foo.0:string(string)" + methodReffed.signature shouldBe "string(string)" + } + + "create a method body for the lambda method" in { + cpg.typeDecl.name("Foo").method.name(".*lambda.*").block.astChildren.l match { + case List(fallBack: Local, returnNode: Return) => + returnNode.code shouldBe "return lambdaInput.length() > 5 ? \"Long\" : fallback;" + returnNode.astChildren.l match { + case List(expr: Call) => + expr.methodFullName shouldBe Operators.conditional + case result => fail(s"Expected return conditional, but got $result") + } + fallBack.name shouldBe "fallback" + case result => fail(s"Expected lambda body with single return but got $result") + } + } + + "create locals for captured identifiers in the lambda method" in { + cpg.typeDecl.name("Foo").method.name(".*lambda.*").local.sortBy(_.name) match { + case Seq(fallbackLocal: Local) => + fallbackLocal.name shouldBe "fallback" + fallbackLocal.code shouldBe "fallback" + fallbackLocal.typeFullName shouldBe "string" + cpg.identifier.nameExact("fallback").refsTo.l shouldBe List(fallbackLocal) + case result => fail(s"Expected single local for fallback but got $result") + } + } + + "create closure bindings for captured identifiers" in { + cpg.all.collectAll[ClosureBinding].sortBy(_.closureOriginalName) match { + case Seq(fallbackClosureBinding) => + val fallbackLocal = cpg.method.name(".*lambda.*").local.name("fallback").head + fallbackClosureBinding.closureBindingId shouldBe fallbackLocal.closureBindingId + + fallbackClosureBinding._refOut.l match { + case List(capturedParam: MethodParameterIn) => + capturedParam.name shouldBe "fallback" + capturedParam.method.fullName shouldBe "Foo.foo:void(string,string)" + case result => fail(s"Expected single capturedParam but got $result") + } + + fallbackClosureBinding._captureIn.l match { + case List(outMethod: MethodRef) => + outMethod.typeFullName shouldBe "Test0.cpp:.Foo.foo.0:string(string)" + outMethod.methodFullName shouldBe "Test0.cpp:.Foo.foo.0:string(string)" + case result => fail(s"Expected single METHOD_REF but got $result") + } + case result => fail(s"Expected 1 closure binding for captured variables but got $result") + } + } + + "create a typeDecl node inheriting from correct interface" in { + cpg.typeDecl.name(".*lambda.*").l match { + case List(lambdaDecl) => + lambdaDecl.name shouldBe "0" + lambdaDecl.fullName shouldBe "Test0.cpp:.Foo.foo.0:string(string)" + lambdaDecl.inheritsFromTypeFullName should contain theSameElementsAs List("std.function") + case result => fail(s"Expected a single typeDecl for the lambda but got $result") + } + } + } + + "lambdas with different return type annotations" should { + val cpg = code(""" + |void foo() { + | auto l1 = [] () -> int { return 1; }; // explicit type + | auto l2 = [] () { return 1; }; // inferred + | auto l3 = [] () -> unknown { return bar(); }; // broken or unknown + | auto l4 = [] () mutable -> int { return 1; }; + | auto l5 = [] () mutable { return 1; }; + |} + |""".stripMargin) + + "have the correct fullname" in { + val List(l0, l1, l2, l3, l4) = cpg.method.name(".*lambda.*").sortBy(_.name).l + l0.fullName shouldBe "Test0.cpp:.foo.0:int()" + l1.fullName shouldBe "Test0.cpp:.foo.1:ANY()" // CDT is unable to infer the type here; needs to be fixed + l2.fullName shouldBe "Test0.cpp:.foo.2:unknown()" + l3.fullName shouldBe "Test0.cpp:.foo.3:int()" + l4.fullName shouldBe "Test0.cpp:.foo.4:ANY()" + } + } + + "lambdas capturing this in method" should { + val cpg = code(""" + |class Foo { + | public: + | int firstDirty; + | void foo() { + | bar(l, [this] { return this->firstDirty == nullptr; }); + | } + |} + |""".stripMargin) + + "ref this correctly" in { + val List(lambda) = cpg.method.name(".*lambda.*").l + lambda.fullName shouldBe "Test0.cpp:.Foo.foo.0:ANY()" + cpg.all.collectAll[ClosureBinding].l match { + case List(thisClosureBinding) => + val thisLocal = cpg.method.name(".*lambda.*").local.nameExact("this").head + thisClosureBinding.closureBindingId shouldBe thisLocal.closureBindingId + + cpg.identifier.nameExact("this").refsTo.l shouldBe List(thisLocal) + + thisClosureBinding._refOut.l match { + case List(capturedThisParam: MethodParameterIn) => + capturedThisParam.name shouldBe "this" + capturedThisParam.typeFullName shouldBe "Foo*" + capturedThisParam.method.fullName shouldBe "Foo.foo:void()" + case result => fail(s"Expected single capturedParam but got $result") + } + + thisClosureBinding._captureIn.l match { + case List(outMethod: MethodRef) => + outMethod.typeFullName shouldBe lambda.fullName + outMethod.methodFullName shouldBe lambda.fullName + case result => fail(s"Expected single METHOD_REF but got $result") + } + case result => fail(s"Expected 1 closure binding for captured variables but got $result") + } + } + } + + "lambdas capturing with shadowing" should { + val cpg = code(""" + |static void foo(int *x) { + | auto f = [=] { float *x = nullptr; }; + |}""".stripMargin) + + "ref the shadowed variable correctly" in { + cpg.all.collectAll[ClosureBinding] shouldBe empty + val List(lambda) = cpg.method.name(".*lambda.*").l + val List(xLocal) = lambda.block.local.nameExact("x").l + xLocal.typeFullName shouldBe "float*" + xLocal.closureBindingId shouldBe None + xLocal.closureBinding shouldBe empty + cpg.identifier.nameExact("x").refsTo.l shouldBe List(xLocal) + } + } + + "lambdas capturing with shadowing in nested lambdas" should { + val cpg = code(""" + |static void foo(int *x) { + | auto f = [=] { + | x = nullptr; // x is captured + | auto nested = [=] { + | float *x = nullptr; // this x here is shadowed + | }; + | }; + |}""".stripMargin) + + "ref the shadowed variable correctly" in { + val List(lambdaF, lambdaNested) = cpg.method.name(".*lambda.*").sortBy(_.lineNumber.get).l + cpg.all.collectAll[ClosureBinding].l match { + case List(xClosureBinding) => + val List(xLocalCaptured) = lambdaF.block.local.nameExact("x").l + xClosureBinding.closureBindingId shouldBe xLocalCaptured.closureBindingId + + cpg.identifier.nameExact("x").lineNumber(4).refsTo.l shouldBe List(xLocalCaptured) + + xClosureBinding._refOut.l match { + case List(capturedThisParam: MethodParameterIn) => + capturedThisParam.name shouldBe "x" + capturedThisParam.typeFullName shouldBe "int*" + capturedThisParam.method.fullName shouldBe "foo:void(int*)" + case result => fail(s"Expected single capturedParam but got $result") + } + + xClosureBinding._captureIn.l match { + case List(outMethod: MethodRef) => + outMethod.typeFullName shouldBe lambdaF.fullName + outMethod.methodFullName shouldBe lambdaF.fullName + case result => fail(s"Expected single METHOD_REF but got $result") + } + case result => fail(s"Expected 1 closure binding for captured variables but got $result") + } + + val List(xLocalNested) = lambdaNested.block.local.nameExact("x").l + xLocalNested.typeFullName shouldBe "float*" + xLocalNested.closureBindingId shouldBe None + xLocalNested.closureBinding shouldBe empty + cpg.identifier.nameExact("x").lineNumber(6).refsTo.l shouldBe List(xLocalNested) + } + } + + "lambdas capturing with shadowing in nested blocks" should { + val cpg = code(""" + |static void foo(int *x) { + | auto f = [&] { + | *x = 0; // capture + | { + | float *x = nullptr; // first shadowing + | { + | double *x = nullptr; // second shadowing + | } + | } + | }; + |} + |""".stripMargin) + + "ref the shadowed variable correctly" in { + val List(x1, x2, x3) = cpg.method.name(".*lambda.*").ast.isLocal.sortBy(_.lineNumber.get).nameExact("x").l + + x1.typeFullName shouldBe "int*" + x1.closureBindingId shouldBe Some("Test0.cpp:0:x") + cpg.identifier.nameExact("x").lineNumber(4).refsTo.l shouldBe List(x1) + + x2.typeFullName shouldBe "float*" + x2.closureBindingId shouldBe None + cpg.identifier.nameExact("x").lineNumber(6).refsTo.l shouldBe List(x2) + + x3.typeFullName shouldBe "double*" + x3.closureBindingId shouldBe None + cpg.identifier.nameExact("x").lineNumber(8).refsTo.l shouldBe List(x3) + } + } + + "lambdas capturing with shadowing in nested blocks 2" should { + val cpg = code(""" + |static void foo(int *x) { + | auto f = [&] { + | *x = 0; // capture + | { + | float *x = nullptr; // shadowing + | { + | *x = 0; + | } + | } + | }; + |} + |""".stripMargin) + + "ref the shadowed variable correctly" in { + val List(x1, x2) = cpg.method.name(".*lambda.*").ast.isLocal.sortBy(_.lineNumber.get).nameExact("x").l + + x1.typeFullName shouldBe "int*" + x1.closureBindingId shouldBe Some("Test0.cpp:0:x") + cpg.identifier.nameExact("x").lineNumber(4).refsTo.l shouldBe List(x1) + + x2.typeFullName shouldBe "float*" + x2.closureBindingId shouldBe None + cpg.identifier.nameExact("x").lineNumber(6).refsTo.l shouldBe List(x2) + + cpg.identifier.nameExact("x").lineNumber(8).refsTo.l shouldBe List(x2) + } + } + + "lambdas capturing with shadowing in nested blocks 3" should { + val cpg = code(""" + |static void foo() { + | int *x; + | auto f = [&] { + | *x = 0; // capture + | { + | float *x = nullptr; // first shadowing + | { + | double *x = nullptr; // second shadowing + | } + | } + | }; + |} + |""".stripMargin) + + "ref the shadowed variable correctly" in { + val List(x1, x2, x3) = cpg.method.name(".*lambda.*").ast.isLocal.sortBy(_.lineNumber.get).nameExact("x").l + + x1.typeFullName shouldBe "int*" + x1.closureBindingId shouldBe Some("Test0.cpp:0:x") + cpg.identifier.nameExact("x").lineNumber(5).refsTo.l shouldBe List(x1) + + x2.typeFullName shouldBe "float*" + x2.closureBindingId shouldBe None + cpg.identifier.nameExact("x").lineNumber(7).refsTo.l shouldBe List(x2) + + x3.typeFullName shouldBe "double*" + x3.closureBindingId shouldBe None + cpg.identifier.nameExact("x").lineNumber(9).refsTo.l shouldBe List(x3) + } + } + + "lambdas capturing with shadowing in nested blocks 4" should { + val cpg = code(""" + |static void foo(int *x) { + | { + | float *x = nullptr; // first shadowing + | { + | double *x = nullptr; // second shadowing + | auto f = [&] { + | *x = 0.0L; // capture of double* + | }; + | } + | } + |} + |""".stripMargin) + + "ref the shadowed variable correctly" in { + val List(x1, x2, x3) = cpg.local.sortBy(_.lineNumber.get).nameExact("x").l + + x1.typeFullName shouldBe "float*" + x1.closureBindingId shouldBe None + cpg.identifier.nameExact("x").lineNumber(4).refsTo.l shouldBe List(x1) + + x2.typeFullName shouldBe "double*" + x2.closureBindingId shouldBe None + cpg.identifier.nameExact("x").lineNumber(6).refsTo.l shouldBe List(x2) + + x3.typeFullName shouldBe "double*" + x3.closureBindingId shouldBe Some("Test0.cpp:0:x") + cpg.identifier.nameExact("x").lineNumber(8).refsTo.l shouldBe List(x3) + } + } + + "lambdas capturing this in global method" should { + val cpg = code(""" + |class Foo {} + |void Foo::foo() { + | bar(l, [this] { return this->firstDirty == nullptr; }); + |} + |""".stripMargin) + + "ref this correctly" in { + val List(lambda) = cpg.method.name(".*lambda.*").l + cpg.all.collectAll[ClosureBinding].l match { + case List(thisClosureBinding) => + val thisLocal = cpg.method.name(".*lambda.*").local.nameExact("this").head + thisClosureBinding.closureBindingId shouldBe thisLocal.closureBindingId + + cpg.identifier.nameExact("this").refsTo.l shouldBe List(thisLocal) + + thisClosureBinding._refOut.l match { + case List(capturedThisParam: MethodParameterIn) => + capturedThisParam.name shouldBe "this" + capturedThisParam.typeFullName shouldBe "Foo*" + capturedThisParam.method.fullName shouldBe "Foo.foo:void()" + case result => fail(s"Expected single capturedParam but got $result") + } + + thisClosureBinding._captureIn.l match { + case List(outMethod: MethodRef) => + outMethod.typeFullName shouldBe lambda.fullName + outMethod.methodFullName shouldBe lambda.fullName + case result => fail(s"Expected single METHOD_REF but got $result") + } + case result => fail(s"Expected 1 closure binding for captured variables but got $result") + } + } + } + + "lambda capturing local variable by value" should { + val cpg = code(""" + |class Foo { + | public: + | void foo(Object arg) { + | string myValue = "abc"; + | std::list userPayload = {}; + | auto userNamesList = userPayload.map([myValue] (string item) -> string { + | sink2(myValue); + | return item + myValue; + | }); + | sink1(userNamesList); + | return; + | } + |} + |""".stripMargin) + + "be captured precisely" in { + cpg.all.collectAll[ClosureBinding].l match { + case myValue :: Nil => + myValue.evaluationStrategy shouldBe EvaluationStrategies.BY_VALUE + myValue.closureOriginalName.head shouldBe "myValue" + myValue._localViaRefOut.get.name shouldBe "myValue" + myValue._captureIn.collectFirst { case x: MethodRef => + x.methodFullName + }.head shouldBe "Test0.cpp:.Foo.foo.0:string(string)" + case result => + fail(s"Expected single closure binding to collect but got $result") + } + } + + } + + "lambda capturing local variable by reference" should { + val cpg = code(""" + |class Foo { + | public: + | void foo(Object arg) { + | string myValue = "abc"; + | std::list userPayload = {}; + | auto userNamesList = userPayload.map([&] (string item) -> string { + | sink2(myValue); + | return item + myValue; + | }); + | sink1(userNamesList); + | return; + | } + |} + |""".stripMargin) + + "be captured precisely" in { + cpg.all.collectAll[ClosureBinding].l match { + case myValue :: Nil => + myValue.evaluationStrategy shouldBe EvaluationStrategies.BY_REFERENCE + myValue.closureOriginalName.head shouldBe "myValue" + myValue._localViaRefOut.get.name shouldBe "myValue" + myValue._captureIn.collectFirst { case x: MethodRef => + x.methodFullName + }.head shouldBe "Test0.cpp:.Foo.foo.0:string(string)" + case result => + fail(s"Expected single closure binding to collect but got $result") + } + } + + } + + "be correct for simple lambda expressions" in { + val cpg = code(""" + |auto x = [] (int a, int b) -> int + | { return a + b; }; + |auto y = [] (string a, string b) -> string + | { return a + b; }; + |""".stripMargin) + val lambda1FullName = "Test0.cpp:.0:int(int,int)" + val lambda2FullName = "Test0.cpp:.1:string(string,string)" + + cpg.local.nameExact("x").order.l shouldBe List(1) + cpg.local.nameExact("y").order.l shouldBe List(3) + + inside(cpg.assignment.l) { case List(assignment1, assignment2) => + assignment1.order shouldBe 2 + inside(assignment1.astMinusRoot.isMethodRef.l) { case List(ref) => + ref.methodFullName shouldBe lambda1FullName + } + assignment2.order shouldBe 4 + inside(assignment2.astMinusRoot.isMethodRef.l) { case List(ref) => + ref.methodFullName shouldBe lambda2FullName + } + } + + inside(cpg.method.fullNameExact(lambda1FullName).isLambda.l) { case List(l1) => + l1.name shouldBe "0" + l1.code should startWith("[] (int a, int b) -> int") + l1.signature shouldBe "int(int,int)" + l1.body.code shouldBe "{ return a + b; }" + } + + inside(cpg.method.fullNameExact(lambda2FullName).isLambda.l) { case List(l2) => + l2.name shouldBe "1" + l2.code should startWith("[] (string a, string b) -> string") + l2.signature shouldBe "string(string,string)" + l2.body.code shouldBe "{ return a + b; }" + } + cpg.typeDecl(NamespaceTraversal.globalNamespaceName).head.bindsOut.size shouldBe 0 + } + + "be correct for simple lambda expression in class" in { + val cpg = code(""" + |class Foo { + | auto x = [] (int a, int b) -> int + | { + | return a + b; + | }; + |}; + | + |""".stripMargin) + val lambdaName = "0" + val signature = "int(int,int)" + val lambdaFullName = s"Test0.cpp:.Foo.$lambdaName:$signature" + + cpg.member.nameExact("x").order.l shouldBe List(1) + + inside(cpg.assignment.l) { case List(assignment1) => + inside(assignment1.astMinusRoot.isMethodRef.l) { case List(ref) => + ref.methodFullName shouldBe lambdaFullName + } + } + + inside(cpg.method.fullNameExact(lambdaFullName).isLambda.l) { case List(l1) => + l1.name shouldBe lambdaName + l1.code should startWith("[] (int a, int b) -> int") + l1.signature shouldBe signature + } + } + + "be correct for simple lambda expression in class under namespaces" in { + val cpg = code(""" + |namespace A { class B { + |class Foo { + | auto x = [] (int a, int b) -> int + | { + | return a + b; + | }; + |}; + |};} + |""".stripMargin) + val lambdaName = "0" + val signature = "int(int,int)" + val lambdaFullName = s"Test0.cpp:.A.B.Foo.$lambdaName:$signature" + + cpg.member.nameExact("x").order.l shouldBe List(1) + + inside(cpg.assignment.l) { case List(assignment1) => + inside(assignment1.astMinusRoot.isMethodRef.l) { case List(ref) => + ref.methodFullName shouldBe lambdaFullName + } + } + + inside(cpg.method.fullNameExact(lambdaFullName).isLambda.l) { case List(l1) => + l1.name shouldBe lambdaName + l1.code should startWith("[] (int a, int b) -> int") + l1.signature shouldBe signature + } + + } + + "be correct when calling a lambda" in { + val cpg = code(""" + |auto x = [](int n) -> int + |{ + | return 32 + n; + |}; + | + |constexpr int foo1 = x(10); + |constexpr int foo2 = [](int n) -> int + |{ + | return 32 + n; + |}(10); + |""".stripMargin) + val signature = "int(int)" + val lambda1Name = "0" + val lambda1FullName = s"Test0.cpp:.$lambda1Name:$signature" + val lambda2Name = "1" + val lambda2FullName = s"Test0.cpp:.$lambda2Name:$signature" + + cpg.local.nameExact("x").order.l shouldBe List(1) + cpg.local.nameExact("foo1").order.l shouldBe List(3) + cpg.local.nameExact("foo2").order.l shouldBe List(5) + + inside(cpg.assignment.l) { case List(assignment1, assignment2, assignment3) => + assignment1.order shouldBe 2 + assignment2.order shouldBe 4 + assignment3.order shouldBe 6 + inside(assignment1.astMinusRoot.isMethodRef.l) { case List(ref) => + ref.methodFullName shouldBe lambda1FullName + } + } + + inside(cpg.method.fullNameExact(lambda1FullName).isLambda.l) { case List(l1) => + l1.name shouldBe lambda1Name + l1.code should startWith("[](int n) -> int") + l1.signature shouldBe signature + } + + inside(cpg.call.nameExact("()").l) { case List(lambda1call, lambda2call) => + lambda1call.name shouldBe "()" + lambda1call.methodFullName shouldBe "():int(int)" + lambda1call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + inside(lambda1call.astChildren.l) { case List(id: Identifier, lit: Literal) => + id.code shouldBe "x" + lit.code shouldBe "10" + } + inside(lambda1call.argument.l) { case List(lit: Literal) => + lit.code shouldBe "10" + } + inside(lambda1call.receiver.l) { case List(receiver: Identifier) => + receiver.code shouldBe "x" + } + + lambda2call.name shouldBe "()" + lambda2call.methodFullName shouldBe "():int(int)" + lambda2call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + inside(lambda2call.astChildren.l) { case List(ref: MethodRef, lit: Literal) => + ref.methodFullName shouldBe lambda2FullName + ref.code shouldBe lambda2FullName + lit.code shouldBe "10" + } + + inside(lambda2call.argument.l) { case List(lit: Literal) => + lit.code shouldBe "10" + } + inside(lambda2call.receiver.l) { case List(ref: MethodRef) => + ref.methodFullName shouldBe lambda2FullName + ref.code shouldBe lambda2FullName + } + } + } + +} diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MemberTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MemberTests.scala index 82b91be23f27..6117f4ae71e0 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MemberTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MemberTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.passes.ast import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MemberTests extends C2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MetaDataTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MetaDataTests.scala index 0724f67523f9..c3c91d81c19b 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MetaDataTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MetaDataTests.scala @@ -6,7 +6,7 @@ import io.joern.x2cpg.layers.CallGraph import io.joern.x2cpg.layers.ControlFlow import io.joern.x2cpg.layers.TypeRelations import io.shiftleft.codepropertygraph.generated.Languages -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MetaDataTests extends C2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodParameterTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodParameterTests.scala index 4c4535132524..d3c12d8c82c2 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodParameterTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodParameterTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.passes.ast import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MethodParameterTests extends C2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodReturnTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodReturnTests.scala index 64ab131c7323..ca33f31a7595 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodReturnTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodReturnTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.passes.ast import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal class MethodReturnTests extends C2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodTests.scala index 194649699a10..514094d9ac15 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/MethodTests.scala @@ -3,7 +3,9 @@ package io.joern.c2cpg.passes.ast import io.joern.c2cpg.testfixtures.C2CpgSuite import io.shiftleft.codepropertygraph.generated.EvaluationStrategies import io.shiftleft.codepropertygraph.generated.NodeTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.nodes.Identifier +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal class MethodTests extends C2CpgSuite { @@ -92,7 +94,7 @@ class MethodTests extends C2CpgSuite { data.index shouldBe 1 data.name shouldBe "data" data.code shouldBe "int &data" - data.typeFullName shouldBe "int" + data.typeFullName shouldBe "int&" data.isVariadic shouldBe false } } @@ -270,8 +272,53 @@ class MethodTests extends C2CpgSuite { } + "Static modifier for methods" should { + "be correct" in { + val cpg = code( + """ + |static void staticCMethodDecl(); + |static void staticCMethodDef() {} + |""".stripMargin, + "test.c" + ).moreCode( + """ + |class A { + | static void staticCPPMethodDecl(); + | static void staticCPPMethodDef() {} + |}; + |""".stripMargin, + "test.cpp" + ) + val List(staticCMethodDecl) = cpg.method.nameExact("staticCMethodDecl").isStatic.l + val List(staticCMethodDef) = cpg.method.nameExact("staticCMethodDef").isStatic.l + val List(staticCPPMethodDecl) = cpg.method.nameExact("staticCPPMethodDecl").isStatic.l + val List(staticCPPMethodDef) = cpg.method.nameExact("staticCPPMethodDef").isStatic.l + staticCMethodDecl.fullName shouldBe "staticCMethodDecl" + staticCMethodDef.fullName shouldBe "staticCMethodDef" + staticCPPMethodDecl.fullName shouldBe "A.staticCPPMethodDecl:void()" + staticCPPMethodDef.fullName shouldBe "A.staticCPPMethodDef:void()" + } + } + + "Name for method parameter in parentheses" should { + "be correct" in { + val cpg = code(""" + |int foo(int * (a)) { + | int (x) = a; + | return 2 * *a; + |} + |""".stripMargin) + val List(paramA) = cpg.method("foo").parameter.l + paramA.code shouldBe "int * (a)" + paramA.typeFullName shouldBe "int*" + paramA.name shouldBe "a" + cpg.identifier.nameExact("x").size shouldBe 1 + cpg.method("foo").local.nameExact("x").size shouldBe 1 + } + } + "Method name, signature and full name tests" should { - "be correct for plain method C" in { + "be correct for plain C method" in { val cpg = code( """ |int method(int); @@ -283,6 +330,21 @@ class MethodTests extends C2CpgSuite { method.fullName shouldBe "method" } + "be correct for C function pointer" in { + val cpg = code( + """ + |int (*foo)(int, int) = { 0 }; + |int (*bar[])(int, int) = { 0 }; + |""".stripMargin, + "test.c" + ) + val List(foo, bar) = cpg.local.l + foo.name shouldBe "foo" + foo.typeFullName shouldBe "int(*)(int,int)" + bar.name shouldBe "bar" + bar.typeFullName shouldBe "int(*[])(int,int)" + } + "be correct for plain method CPP" in { val cpg = code( """ @@ -328,5 +390,66 @@ class MethodTests extends C2CpgSuite { method.signature shouldBe "int(int)" method.fullName shouldBe "NNN.CCC.method:int(int)" } + + "be correct for class method with implicit member access" in { + val cpg = code( + """ + |class A { + | int var; + | void meth(); + |}; + |namespace Foo { + | void A::meth() { + | assert(this->var == var); + | } + |}""".stripMargin, + "test.cpp" + ) + val List(implicitThisParam) = cpg.method.name("meth").parameter.l + implicitThisParam.name shouldBe "this" + implicitThisParam.typeFullName shouldBe "A*" + val List(trueVarAccess) = cpg.call.name(Operators.equals).argument.argumentIndex(1).isCall.l + trueVarAccess.code shouldBe "this->var" + trueVarAccess.name shouldBe Operators.indirectFieldAccess + val List(trueThisId, trueVarFieldIdent) = trueVarAccess.argument.l + trueThisId.code shouldBe "this" + trueThisId.isIdentifier shouldBe true + trueThisId.asInstanceOf[Identifier].typeFullName shouldBe "A*" + trueThisId._refOut.l shouldBe List(implicitThisParam) + trueVarFieldIdent.code shouldBe "var" + trueVarFieldIdent.isFieldIdentifier shouldBe true + + val List(varAccess) = cpg.call.name(Operators.equals).argument.argumentIndex(2).isCall.l + varAccess.code shouldBe "this->var" + varAccess.name shouldBe Operators.indirectFieldAccess + val List(thisId, varFieldIdent) = varAccess.argument.l + thisId.code shouldBe "this" + thisId.isIdentifier shouldBe true + thisId.asInstanceOf[Identifier].typeFullName shouldBe "A*" + thisId._refOut.l shouldBe List(implicitThisParam) + varFieldIdent.code shouldBe "var" + varFieldIdent.isFieldIdentifier shouldBe true + } + + "be correct for class method in nested class" in { + val cpg = code( + """class Outer { + | class Inner { + | void Method(); + | int member; + | }; + |}; + |void Outer::Inner::Method() { + | member; + |}""".stripMargin, + "test.cpp" + ) + cpg.identifier.name("member").size shouldBe 0 + val List(memberCall) = cpg.call.codeExact("this->member").l + memberCall.typeFullName shouldBe "int" + memberCall.name shouldBe Operators.indirectFieldAccess + memberCall.argument.isIdentifier.typeFullName.l shouldBe List("Outer.Inner*") + memberCall.argument.isFieldIdentifier.code.l shouldBe List("member") + } } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/NamespaceBlockTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/NamespaceBlockTests.scala index 200c7bed77ce..335c187314d6 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/NamespaceBlockTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/NamespaceBlockTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.passes.ast import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/ProgramStructureTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/ProgramStructureTests.scala index cda8183dcc8a..bef636625dac 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/ProgramStructureTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/ast/ProgramStructureTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.passes.ast import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal class ProgramStructureTests extends C2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/cfg/CfgCreationPassTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/cfg/CfgCreationPassTests.scala index 6ade9688aa0e..82638bb2415a 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/cfg/CfgCreationPassTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/cfg/CfgCreationPassTests.scala @@ -14,319 +14,341 @@ class CfgCreationPassTests extends CfgTestFixture(() => new CCfgTestCpg) { "Cfg" should { "contain an entry and exit node at least" in { implicit val cpg: Cpg = code("") - succOf("func") shouldBe expected(("RET", AlwaysEdge)) - succOf("RET") shouldBe expected() + succOf("func") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("RET") should contain theSameElementsAs expected() } "be correct for decl statement with assignment" in { implicit val cpg: Cpg = code("int x = 1;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x = 1", AlwaysEdge)) - succOf("x = 1") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x = 1", AlwaysEdge)) + succOf("x = 1") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for nested expression" in { implicit val cpg: Cpg = code("x = y + 1;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y + 1", AlwaysEdge)) - succOf("y + 1") shouldBe expected(("x = y + 1", AlwaysEdge)) - succOf("x = y + 1") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y + 1", AlwaysEdge)) + succOf("y + 1") should contain theSameElementsAs expected(("x = y + 1", AlwaysEdge)) + succOf("x = y + 1") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for return statement" in { implicit val cpg: Cpg = code("return x;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("return x;", AlwaysEdge)) - succOf("return x;") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("return x;", AlwaysEdge)) + succOf("return x;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for consecutive return statements" in { implicit val cpg: Cpg = code("return x; return y;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("return x;", AlwaysEdge)) - succOf("y") shouldBe expected(("return y;", AlwaysEdge)) - succOf("return x;") shouldBe expected(("RET", AlwaysEdge)) - succOf("return y;") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("return x;", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("return y;", AlwaysEdge)) + succOf("return x;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("return y;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for void return statement" in { implicit val cpg: Cpg = code("return;") - succOf("func") shouldBe expected(("return;", AlwaysEdge)) - succOf("return;") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("return;", AlwaysEdge)) + succOf("return;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for call expression" in { implicit val cpg: Cpg = code("foo(a + 1, b);") - succOf("func") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("a + 1", AlwaysEdge)) - succOf("a + 1") shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("foo(a + 1, b)", AlwaysEdge)) - succOf("foo(a + 1, b)") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("a + 1", AlwaysEdge)) + succOf("a + 1") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("foo(a + 1, b)", AlwaysEdge)) + succOf("foo(a + 1, b)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for unary expression '+'" in { implicit val cpg: Cpg = code("+x;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("+x", AlwaysEdge)) - succOf("+x") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("+x", AlwaysEdge)) + succOf("+x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for unary expression '++'" in { implicit val cpg: Cpg = code("++x;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("++x", AlwaysEdge)) - succOf("++x") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("++x", AlwaysEdge)) + succOf("++x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for conditional expression" in { implicit val cpg: Cpg = code("x ? y : z;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("z", FalseEdge)) - succOf("y") shouldBe expected(("x ? y : z", AlwaysEdge)) - succOf("z") shouldBe expected(("x ? y : z", AlwaysEdge)) - succOf("x ? y : z") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("z", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("x ? y : z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("x ? y : z", AlwaysEdge)) + succOf("x ? y : z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for conditional expression with empty then" in { implicit val cpg: Cpg = code("x ? : z;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("x ? : z", TrueEdge), ("z", FalseEdge)) - succOf("z") shouldBe expected(("x ? : z", AlwaysEdge)) - succOf("x ? : z") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("x ? : z", TrueEdge), ("z", FalseEdge)) + succOf("z") should contain theSameElementsAs expected(("x ? : z", AlwaysEdge)) + succOf("x ? : z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for short-circuit AND expression" in { implicit val cpg: Cpg = code("int z = x && y;") - succOf("func") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("x && y", FalseEdge)) - succOf("y") shouldBe expected(("x && y", AlwaysEdge)) - succOf("x && y") shouldBe expected(("z = x && y", AlwaysEdge)) - succOf("z = x && y") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("x && y", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("x && y", AlwaysEdge)) + succOf("x && y") should contain theSameElementsAs expected(("z = x && y", AlwaysEdge)) + succOf("z = x && y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for short-circuit OR expression" in { implicit val cpg: Cpg = code("x || y;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", FalseEdge), ("x || y", TrueEdge)) - succOf("y") shouldBe expected(("x || y", AlwaysEdge)) - succOf("x || y") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", FalseEdge), ("x || y", TrueEdge)) + succOf("y") should contain theSameElementsAs expected(("x || y", AlwaysEdge)) + succOf("x || y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } "Cfg for while-loop" should { "be correct" in { implicit val cpg: Cpg = code("while (x < 1) { y = 2; }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("y", TrueEdge), ("RET", FalseEdge)) - succOf("y") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("y = 2", AlwaysEdge)) - succOf("y = 2") shouldBe expected(("x", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("y", TrueEdge), ("RET", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("y = 2", AlwaysEdge)) + succOf("y = 2") should contain theSameElementsAs expected(("x", AlwaysEdge)) } "be correct with break" in { implicit val cpg: Cpg = code("while (x < 1) { break; y; }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("break;", TrueEdge), ("RET", FalseEdge)) - succOf("break;") shouldBe expected(("RET", AlwaysEdge)) - succOf("y") shouldBe expected(("x", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("break;", TrueEdge), ("RET", FalseEdge)) + succOf("break;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("x", AlwaysEdge)) } "be correct with continue" in { implicit val cpg: Cpg = code("while (x < 1) { continue; y; }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("continue;", TrueEdge), ("RET", FalseEdge)) - succOf("continue;") shouldBe expected(("x", AlwaysEdge)) - succOf("y") shouldBe expected(("x", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("continue;", TrueEdge), ("RET", FalseEdge)) + succOf("continue;") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("x", AlwaysEdge)) } "be correct with nested while-loop" in { implicit val cpg: Cpg = code("while (x) { while (y) { z; }}") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("RET", FalseEdge)) - succOf("y") shouldBe expected(("z", TrueEdge), ("x", FalseEdge)) - succOf("z") shouldBe expected(("y", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("RET", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("z", TrueEdge), ("x", FalseEdge)) + succOf("z") should contain theSameElementsAs expected(("y", AlwaysEdge)) } } "Cfg for do-while-loop" should { "be correct" in { implicit val cpg: Cpg = code("do { y = 2; } while (x < 1);") - succOf("func") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("y = 2", AlwaysEdge)) - succOf("y = 2") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("y", TrueEdge), ("RET", FalseEdge)) + succOf("func") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("y = 2", AlwaysEdge)) + succOf("y = 2") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("y", TrueEdge), ("RET", FalseEdge)) } "be correct with break" in { implicit val cpg: Cpg = code("do { break; y; } while (x < 1);") - succOf("func") shouldBe expected(("break;", AlwaysEdge)) - succOf("break;") shouldBe expected(("RET", AlwaysEdge)) - succOf("y") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("break;", TrueEdge), ("RET", FalseEdge)) + succOf("func") should contain theSameElementsAs expected(("break;", AlwaysEdge)) + succOf("break;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("break;", TrueEdge), ("RET", FalseEdge)) } "be correct with continue" in { implicit val cpg: Cpg = code("do { continue; y; } while (x < 1);") - succOf("func") shouldBe expected(("continue;", AlwaysEdge)) - succOf("continue;") shouldBe expected(("x", AlwaysEdge)) - succOf("y") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("continue;", TrueEdge), ("RET", FalseEdge)) + succOf("func") should contain theSameElementsAs expected(("continue;", AlwaysEdge)) + succOf("continue;") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("continue;", TrueEdge), ("RET", FalseEdge)) } "be correct with nested do-while-loop" in { implicit val cpg: Cpg = code("do { do { x; } while (y); } while (z);") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("x", TrueEdge), ("z", FalseEdge)) - succOf("z") shouldBe expected(("x", TrueEdge), ("RET", FalseEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("x", TrueEdge), ("z", FalseEdge)) + succOf("z") should contain theSameElementsAs expected(("x", TrueEdge), ("RET", FalseEdge)) } "be correct for do-while-loop with empty body" in { implicit val cpg: Cpg = code("do { } while(x > 1);") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("1") shouldBe expected(("x > 1", AlwaysEdge)) - succOf("x > 1") shouldBe expected(("x", TrueEdge), ("RET", FalseEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x > 1", AlwaysEdge)) + succOf("x > 1") should contain theSameElementsAs expected(("x", TrueEdge), ("RET", FalseEdge)) + } + + "be correct with multiple macro calls" in { + implicit val cpg: Cpg = code( + """ + |#define deleteReset(ptr) do { delete ptr; ptr = nullptr; } while(0) + |void func(void) { + | int *foo = new int; + | int *bar = new int; + | int *baz = new int; + | deleteReset(foo); + | deleteReset(bar); + | deleteReset(baz); + |} + |""".stripMargin, + "foo.cc" + ) + succOf("deleteReset(foo)") should contain theSameElementsAs expected(("foo", 2, AlwaysEdge), ("bar", AlwaysEdge)) + succOf("foo", 2) should contain theSameElementsAs expected(("delete foo", AlwaysEdge)) + succOf("deleteReset(bar)") should contain theSameElementsAs expected(("bar", 2, AlwaysEdge), ("baz", AlwaysEdge)) + succOf("bar", 2) should contain theSameElementsAs expected(("delete bar", AlwaysEdge)) + succOf("deleteReset(baz)") should contain theSameElementsAs expected(("baz", 2, AlwaysEdge), ("RET", AlwaysEdge)) + succOf("baz", 2) should contain theSameElementsAs expected(("delete baz", AlwaysEdge)) } - } "Cfg for for-loop" should { "be correct" in { implicit val cpg: Cpg = code("for (x = 0; y < 1; z += 2) { a = 3; }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("0", AlwaysEdge)) - succOf("0") shouldBe expected(("x = 0", AlwaysEdge)) - succOf("x = 0") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y < 1", AlwaysEdge)) - succOf("y < 1") shouldBe expected(("a", TrueEdge), ("RET", FalseEdge)) - succOf("a") shouldBe expected(("3", AlwaysEdge)) - succOf("3") shouldBe expected(("a = 3", AlwaysEdge)) - succOf("a = 3") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("z += 2", AlwaysEdge)) - succOf("z += 2") shouldBe expected(("y", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("0") should contain theSameElementsAs expected(("x = 0", AlwaysEdge)) + succOf("x = 0") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y < 1", AlwaysEdge)) + succOf("y < 1") should contain theSameElementsAs expected(("a", TrueEdge), ("RET", FalseEdge)) + succOf("a") should contain theSameElementsAs expected(("3", AlwaysEdge)) + succOf("3") should contain theSameElementsAs expected(("a = 3", AlwaysEdge)) + succOf("a = 3") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("z += 2", AlwaysEdge)) + succOf("z += 2") should contain theSameElementsAs expected(("y", AlwaysEdge)) } "be correct with break" in { implicit val cpg: Cpg = code("for (x = 0; y < 1; z += 2) { break; a = 3; }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("0", AlwaysEdge)) - succOf("x = 0") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y < 1", AlwaysEdge)) - succOf("y < 1") shouldBe expected(("break;", TrueEdge), ("RET", FalseEdge)) - succOf("break;") shouldBe expected(("RET", AlwaysEdge)) - succOf("a") shouldBe expected(("3", AlwaysEdge)) - succOf("3") shouldBe expected(("a = 3", AlwaysEdge)) - succOf("a = 3") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("z += 2", AlwaysEdge)) - succOf("z += 2") shouldBe expected(("y", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("x = 0") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y < 1", AlwaysEdge)) + succOf("y < 1") should contain theSameElementsAs expected(("break;", TrueEdge), ("RET", FalseEdge)) + succOf("break;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("3", AlwaysEdge)) + succOf("3") should contain theSameElementsAs expected(("a = 3", AlwaysEdge)) + succOf("a = 3") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("z += 2", AlwaysEdge)) + succOf("z += 2") should contain theSameElementsAs expected(("y", AlwaysEdge)) } "be correct with continue" in { implicit val cpg: Cpg = code("for (x = 0; y < 1; z += 2) { continue; a = 3; }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("0", AlwaysEdge)) - succOf("0") shouldBe expected(("x = 0", AlwaysEdge)) - succOf("x = 0") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y < 1", AlwaysEdge)) - succOf("y < 1") shouldBe expected(("continue;", TrueEdge), ("RET", FalseEdge)) - succOf("continue;") shouldBe expected(("z", AlwaysEdge)) - succOf("a") shouldBe expected(("3", AlwaysEdge)) - succOf("3") shouldBe expected(("a = 3", AlwaysEdge)) - succOf("a = 3") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("z += 2", AlwaysEdge)) - succOf("z += 2") shouldBe expected(("y", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("0") should contain theSameElementsAs expected(("x = 0", AlwaysEdge)) + succOf("x = 0") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y < 1", AlwaysEdge)) + succOf("y < 1") should contain theSameElementsAs expected(("continue;", TrueEdge), ("RET", FalseEdge)) + succOf("continue;") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("3", AlwaysEdge)) + succOf("3") should contain theSameElementsAs expected(("a = 3", AlwaysEdge)) + succOf("a = 3") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("z += 2", AlwaysEdge)) + succOf("z += 2") should contain theSameElementsAs expected(("y", AlwaysEdge)) } "be correct with nested for-loop" in { implicit val cpg: Cpg = code("for (x; y; z) { for (a; b; c) { u; } }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("a", TrueEdge), ("RET", FalseEdge)) - succOf("z") shouldBe expected(("y", AlwaysEdge)) - succOf("a") shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("u", TrueEdge), ("z", FalseEdge)) - succOf("c") shouldBe expected(("b", AlwaysEdge)) - succOf("u") shouldBe expected(("c", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("a", TrueEdge), ("RET", FalseEdge)) + succOf("z") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("u", TrueEdge), ("z", FalseEdge)) + succOf("c") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("u") should contain theSameElementsAs expected(("c", AlwaysEdge)) } "be correct with empty condition" in { implicit val cpg: Cpg = code("for (;;) { a = 1; }") - succOf("func") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("a = 1", AlwaysEdge)) - succOf("a = 1") shouldBe expected(("a", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("a = 1", AlwaysEdge)) + succOf("a = 1") should contain theSameElementsAs expected(("a", AlwaysEdge)) } "be correct with empty condition with break" in { implicit val cpg: Cpg = code("for (;;) { break; }") - succOf("func") shouldBe expected(("break;", AlwaysEdge)) - succOf("break;") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("break;", AlwaysEdge)) + succOf("break;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct with empty condition with continue" in { implicit val cpg: Cpg = code("for (;;) { continue ; }") - succOf("func") shouldBe expected(("continue ;", AlwaysEdge)) - succOf("continue ;") shouldBe expected(("continue ;", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("continue ;", AlwaysEdge)) + succOf("continue ;") should contain theSameElementsAs expected(("continue ;", AlwaysEdge)) } "be correct with empty condition with nested empty for-loop" in { implicit val cpg: Cpg = code("for (;;) { for (;;) { x; } }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("x", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("x", AlwaysEdge)) } "be correct with empty condition with empty block" in { implicit val cpg: Cpg = code("for (;;) ;") - succOf("func") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct when empty for-loop is skipped" in { implicit val cpg: Cpg = code("for (;;) {}; return;") - succOf("func") shouldBe expected(("return;", AlwaysEdge)) - succOf("return;") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("return;", AlwaysEdge)) + succOf("return;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct with function call condition with empty block" in { implicit val cpg: Cpg = code("for (; x(1);) ;") - succOf("func") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x(1)", AlwaysEdge)) - succOf("x(1)") shouldBe expected(("1", TrueEdge), ("RET", FalseEdge)) + succOf("func") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x(1)", AlwaysEdge)) + succOf("x(1)") should contain theSameElementsAs expected(("1", TrueEdge), ("RET", FalseEdge)) } } "Cfg for goto" should { "be correct for single label" in { implicit val cpg: Cpg = code("x; goto l1; y; l1: ;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("goto l1;", AlwaysEdge)) - succOf("goto l1;") shouldBe expected(("l1: ;", AlwaysEdge)) - succOf("l1: ;") shouldBe expected(("RET", AlwaysEdge)) - succOf("y") shouldBe expected(("l1: ;", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("goto l1;", AlwaysEdge)) + succOf("goto l1;") should contain theSameElementsAs expected(("l1: ;", AlwaysEdge)) + succOf("l1: ;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("l1: ;", AlwaysEdge)) } "be correct for GNU goto labels as values" in { @@ -336,40 +358,40 @@ class CfgCreationPassTests extends CfgTestFixture(() => new CCfgTestCpg) { |otherCall(); |foo: someCall(); |""".stripMargin) - succOf("func") shouldBe expected(("ptr", AlwaysEdge)) - succOf("ptr") shouldBe expected(("foo", AlwaysEdge)) - succOf("ptr", 1) shouldBe expected(("*ptr", AlwaysEdge)) - succOf("foo") shouldBe expected(("&&foo", AlwaysEdge)) - succOf("*ptr = &&foo") shouldBe expected(("goto *;", AlwaysEdge)) - succOf("goto *;") shouldBe expected(("foo: someCall();", AlwaysEdge)) - succOf("foo: someCall();") shouldBe expected(("someCall()", AlwaysEdge)) - succOf("otherCall()") shouldBe expected(("foo: someCall();", AlwaysEdge)) - succOf("someCall()") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("ptr", AlwaysEdge)) + succOf("ptr") should contain theSameElementsAs expected(("foo", AlwaysEdge)) + succOf("ptr", 1) should contain theSameElementsAs expected(("*ptr", AlwaysEdge)) + succOf("foo") should contain theSameElementsAs expected(("&&foo", AlwaysEdge)) + succOf("*ptr = &&foo") should contain theSameElementsAs expected(("goto *;", AlwaysEdge)) + succOf("goto *;") should contain theSameElementsAs expected(("foo: someCall();", AlwaysEdge)) + succOf("foo: someCall();") should contain theSameElementsAs expected(("someCall()", AlwaysEdge)) + succOf("otherCall()") should contain theSameElementsAs expected(("foo: someCall();", AlwaysEdge)) + succOf("someCall()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for multiple labels" in { implicit val cpg: Cpg = code("x; goto l1; l2: y; l1: ;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("goto l1;", AlwaysEdge)) - succOf("goto l1;") shouldBe expected(("l1: ;", AlwaysEdge)) - succOf("y") shouldBe expected(("l1: ;", AlwaysEdge)) - succOf("l1: ;") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("goto l1;", AlwaysEdge)) + succOf("goto l1;") should contain theSameElementsAs expected(("l1: ;", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("l1: ;", AlwaysEdge)) + succOf("l1: ;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for multiple labels on same spot" in { implicit val cpg: Cpg = code("x; goto l2; y; l1: ;l2: ;") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("goto l2;", AlwaysEdge)) - succOf("goto l2;") shouldBe expected(("l2: ;", AlwaysEdge)) - succOf("y") shouldBe expected(("l1: ;", AlwaysEdge)) - succOf("l1: ;") shouldBe expected(("l2: ;", AlwaysEdge)) - succOf("l2: ;") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("goto l2;", AlwaysEdge)) + succOf("goto l2;") should contain theSameElementsAs expected(("l2: ;", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("l1: ;", AlwaysEdge)) + succOf("l1: ;") should contain theSameElementsAs expected(("l2: ;", AlwaysEdge)) + succOf("l2: ;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "work correctly with if block" in { implicit val cpg: Cpg = code("if(foo) goto end; if(bar) { f(x); } end: ;") - succOf("func") shouldBe expected(("foo", AlwaysEdge)) - succOf("goto end;") shouldBe expected(("end: ;", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("foo", AlwaysEdge)) + succOf("goto end;") should contain theSameElementsAs expected(("end: ;", AlwaysEdge)) } } @@ -377,85 +399,93 @@ class CfgCreationPassTests extends CfgTestFixture(() => new CCfgTestCpg) { "Cfg for switch" should { "be correct with one case" in { implicit val cpg: Cpg = code("switch (x) { case 1: y; }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("RET", CaseEdge)) - succOf("case 1:") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("case 1:", CaseEdge), ("RET", CaseEdge)) + succOf("case 1:") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct with multiple cases" in { implicit val cpg: Cpg = code("switch (x) { case 1: y; case 2: z;}") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("case 2:", CaseEdge), ("RET", CaseEdge)) - succOf("case 1:") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("case 2:", AlwaysEdge)) - succOf("case 2:") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected( + ("case 1:", CaseEdge), + ("case 2:", CaseEdge), + ("RET", CaseEdge) + ) + succOf("case 1:") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("case 2:", AlwaysEdge)) + succOf("case 2:") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct with multiple cases on same spot" in { implicit val cpg: Cpg = code("switch (x) { case 1: case 2: y; }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("case 2:", CaseEdge), ("RET", CaseEdge)) - succOf("case 1:") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("case 2:", AlwaysEdge)) - succOf("case 2:") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected( + ("case 1:", CaseEdge), + ("case 2:", CaseEdge), + ("RET", CaseEdge) + ) + succOf("case 1:") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("case 2:", AlwaysEdge)) + succOf("case 2:") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct with multiple cases and multiple cases on same spot" in { implicit val cpg: Cpg = code("switch (x) { case 1: case 2: y; case 3: z;}") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected( + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected( ("case 1:", CaseEdge), ("case 2:", CaseEdge), ("case 3:", CaseEdge), ("RET", CaseEdge) ) - succOf("case 1:") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("case 2:", AlwaysEdge)) - succOf("case 2:") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("case 3:", AlwaysEdge)) - succOf("case 3:") shouldBe expected(("3", AlwaysEdge)) - succOf("3") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("RET", AlwaysEdge)) + succOf("case 1:") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("case 2:", AlwaysEdge)) + succOf("case 2:") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("case 3:", AlwaysEdge)) + succOf("case 3:") should contain theSameElementsAs expected(("3", AlwaysEdge)) + succOf("3") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct with default case" in { implicit val cpg: Cpg = code("switch (x) { default: y; }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("default:", CaseEdge)) - succOf("default:") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("default:", CaseEdge)) + succOf("default:") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for case and default combined" in { implicit val cpg: Cpg = code("switch (x) { case 1: y; break; default: z;}") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("default:", CaseEdge)) - succOf("case 1:") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("break;", AlwaysEdge)) - succOf("break;") shouldBe expected(("RET", AlwaysEdge)) - succOf("default:") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("case 1:", CaseEdge), ("default:", CaseEdge)) + succOf("case 1:") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("break;", AlwaysEdge)) + succOf("break;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("default:") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for nested switch" in { implicit val cpg: Cpg = code("switch (x) { case 1: switch(y) { default: z; } }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("RET", AlwaysEdge)) - succOf("case 1:") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("default:", CaseEdge)) - succOf("default:") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("case 1:", CaseEdge), ("RET", AlwaysEdge)) + succOf("case 1:") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("default:", CaseEdge)) + succOf("default:") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for switch containing continue statement" in { @@ -467,47 +497,67 @@ class CfgCreationPassTests extends CfgTestFixture(() => new CCfgTestCpg) { | } |} |""".stripMargin) - succOf("continue;") shouldBe expected(("i", AlwaysEdge)) + succOf("continue;") should contain theSameElementsAs expected(("i", AlwaysEdge)) } } "Cfg for if" should { "be correct" in { implicit val cpg: Cpg = code("if (x) { y; }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("RET", FalseEdge)) - succOf("y") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("RET", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct with else block" in { implicit val cpg: Cpg = code("if (x) { y; } else { z; }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("z", FalseEdge)) - succOf("y") shouldBe expected(("RET", AlwaysEdge)) - succOf("z") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("z", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct with nested if" in { implicit val cpg: Cpg = code("if (x) { if (y) { z; } }") - succOf("func") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("RET", FalseEdge)) - succOf("y") shouldBe expected(("z", TrueEdge), ("RET", FalseEdge)) - succOf("z") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("RET", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("z", TrueEdge), ("RET", FalseEdge)) + succOf("z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct with else if chain" in { implicit val cpg: Cpg = code("if (a) { b; } else if (c) { d;} else { e; }") - succOf("func") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("b", TrueEdge), ("c", FalseEdge)) - succOf("b") shouldBe expected(("RET", AlwaysEdge)) - succOf("c") shouldBe expected(("d", TrueEdge), ("e", FalseEdge)) - succOf("d") shouldBe expected(("RET", AlwaysEdge)) - succOf("e") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("b", TrueEdge), ("c", FalseEdge)) + succOf("b") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("c") should contain theSameElementsAs expected(("d", TrueEdge), ("e", FalseEdge)) + succOf("d") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("e") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + } + + "be correct for empty 'then' block" in { + implicit val cpg: Cpg = code("if (cond()) {} else { foo(); }") + succOf("func") should contain theSameElementsAs expected(("cond()", AlwaysEdge)) + succOf("cond()") should contain theSameElementsAs expected(("RET", TrueEdge), ("foo()", FalseEdge)) + succOf("foo()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + } + + "be correct for empty 'else' block" in { + implicit val cpg: Cpg = code("if (cond()) {foo();} else {}") + succOf("func") should contain theSameElementsAs expected(("cond()", AlwaysEdge)) + succOf("cond()") should contain theSameElementsAs expected(("RET", FalseEdge), ("foo()", TrueEdge)) + succOf("foo()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + } + + "be correct for empty 'then' and 'else' block" in { + implicit val cpg: Cpg = code("if (cond()) {} else {}") + succOf("func") should contain theSameElementsAs expected(("cond()", AlwaysEdge)) + succOf("cond()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } } -class CppCfgCreationPassTests extends CfgTestFixture(() => new CCfgTestCpg(FileDefaults.CPP_EXT)) { +class CppCfgCreationPassTests extends CfgTestFixture(() => new CCfgTestCpg(FileDefaults.CppExt)) { override def code(code: String): CCfgTestCpg = { super.code(s"RET func() { $code }") } @@ -516,9 +566,9 @@ class CppCfgCreationPassTests extends CfgTestFixture(() => new CCfgTestCpg(FileD "be correct for try with a single catch" in { implicit val cpg: Cpg = code("try { a; } catch (int x) { b; }") - succOf("func") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("b", AlwaysEdge), ("RET", AlwaysEdge)) - succOf("b") shouldBe expected(("RET", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("b", AlwaysEdge), ("RET", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for try with multiple catches" in { @@ -533,13 +583,41 @@ class CppCfgCreationPassTests extends CfgTestFixture(() => new CCfgTestCpg(FileD | d; |} |""".stripMargin) - succOf("func") shouldBe expected(("a", AlwaysEdge)) + succOf("func") should contain theSameElementsAs expected(("a", AlwaysEdge)) // Try should have an edge to all catches and return - succOf("a") shouldBe expected(("b", AlwaysEdge), ("c", AlwaysEdge), ("d", AlwaysEdge), ("RET", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected( + ("b", AlwaysEdge), + ("c", AlwaysEdge), + ("d", AlwaysEdge), + ("RET", AlwaysEdge) + ) // But catches should only have edges to return - succOf("b") shouldBe expected(("RET", AlwaysEdge)) - succOf("c") shouldBe expected(("RET", AlwaysEdge)) - succOf("d") shouldBe expected(("RET", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("c") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("d") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + } + + "be correct for throw statement" in { + implicit val cpg: Cpg = code(""" + |throw foo(); + |bar(); + |""".stripMargin) + succOf("func") should contain theSameElementsAs expected(("foo()", AlwaysEdge)) + succOf("foo()") should contain theSameElementsAs expected(("throw foo()", AlwaysEdge)) + succOf("throw foo()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("bar()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + } + + "be correct for throw statement in if-else" in { + implicit val cpg: Cpg = code(""" + |if (true) throw foo(); + |else bar(); + |""".stripMargin) + succOf("func") should contain theSameElementsAs expected(("true", AlwaysEdge)) + succOf("true") should contain theSameElementsAs expected(("foo()", TrueEdge), ("bar()", FalseEdge)) + succOf("foo()") should contain theSameElementsAs expected(("throw foo()", AlwaysEdge)) + succOf("throw foo()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("bar()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/cfg/MethodCfgLayoutTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/cfg/MethodCfgLayoutTests.scala index d2c5042faa88..27b285307716 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/cfg/MethodCfgLayoutTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/cfg/MethodCfgLayoutTests.scala @@ -1,9 +1,9 @@ package io.joern.c2cpg.passes.cfg import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MethodCfgLayoutTests extends C2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/ClassTypeTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/ClassTypeTests.scala index c746c034f705..4615364c8b50 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/ClassTypeTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/ClassTypeTests.scala @@ -2,10 +2,10 @@ package io.joern.c2cpg.passes.types import io.joern.c2cpg.parser.FileDefaults import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal -class ClassTypeTests extends C2CpgSuite(FileDefaults.CPP_EXT) { +class ClassTypeTests extends C2CpgSuite(FileDefaults.CppExt) { "handling C++ classes (code example 1)" should { val cpg = code(""" @@ -64,22 +64,22 @@ class ClassTypeTests extends C2CpgSuite(FileDefaults.CPP_EXT) { "handling C++ classes (code example 2)" should { val cpg = code(""" - |class foo : bar { + |class Foo : Bar { | char x; | int y; - | int method () {} + | int method() {} |}; |typedef int mytype;""".stripMargin) - "should contain a type decl for `foo` with correct fields" in { - val List(x) = cpg.typeDecl("foo").l - x.fullName shouldBe "foo" + "should contain a type decl for `Foo` with correct fields" in { + val List(x) = cpg.typeDecl("Foo").l + x.fullName shouldBe "Foo" x.isExternal shouldBe false - x.inheritsFromTypeFullName shouldBe List("bar") + x.inheritsFromTypeFullName shouldBe List("Bar") x.aliasTypeFullName shouldBe None x.order shouldBe 1 x.filename shouldBe "Test0.cpp" - x.filename.endsWith(FileDefaults.CPP_EXT) shouldBe true + x.filename.endsWith(FileDefaults.CppExt) shouldBe true } "should contain type decl for alias `mytype` of `int`" in { @@ -91,7 +91,7 @@ class ClassTypeTests extends C2CpgSuite(FileDefaults.CPP_EXT) { x.code shouldBe "typedef int mytype;" x.order shouldBe 2 x.filename shouldBe "Test0.cpp" - x.filename.endsWith(FileDefaults.CPP_EXT) shouldBe true + x.filename.endsWith(FileDefaults.CppExt) shouldBe true } "should contain type decl for external type `int`" in { @@ -105,15 +105,15 @@ class ClassTypeTests extends C2CpgSuite(FileDefaults.CPP_EXT) { } "should find exactly 1 internal type" in { - cpg.typeDecl.nameNot(NamespaceTraversal.globalNamespaceName).internal.name.toSetMutable shouldBe Set("foo") + cpg.typeDecl.nameNot(NamespaceTraversal.globalNamespaceName).internal.name.toSetMutable shouldBe Set("Foo") } - "should find five external types (`bar`, `char`, `int`, `void`, `ANY`)" in { - cpg.typeDecl.external.name.toSetMutable shouldBe Set("bar", "char", "int", "void", "ANY") + "should find external type decls" in { + cpg.typeDecl.external.name.sorted.toSetMutable shouldBe Set("ANY", "Bar", "Foo*", "char", "int", "void") } - "should find two members for `foo`: `x` and `y`" in { - cpg.typeDecl.name("foo").member.name.toSetMutable shouldBe Set("x", "y") + "should find two members for `Foo`: `x` and `y`" in { + cpg.typeDecl.name("Foo").member.name.toSetMutable shouldBe Set("x", "y") } "should allow traversing from `int` to its alias `mytype`" in { @@ -122,11 +122,11 @@ class ClassTypeTests extends C2CpgSuite(FileDefaults.CPP_EXT) { } "should find one method in type `foo`" in { - cpg.typeDecl.name("foo").method.name.toSetMutable shouldBe Set("method") + cpg.typeDecl.name("Foo").method.name.toSetMutable shouldBe Set("method") } "should allow traversing from type to enclosing file" in { - cpg.typeDecl.file.filter(_.name.endsWith(FileDefaults.CPP_EXT)).l should not be empty + cpg.typeDecl.file.filter(_.name.endsWith(FileDefaults.CppExt)).l should not be empty } } @@ -145,6 +145,7 @@ class ClassTypeTests extends C2CpgSuite(FileDefaults.CPP_EXT) { |public: | void foo1() { | b.foo2(); + | B x = b; | } |}; | @@ -156,6 +157,7 @@ class ClassTypeTests extends C2CpgSuite(FileDefaults.CPP_EXT) { val List(call) = cpg.call("foo2").l call.methodFullName shouldBe "B.foo2:void()" + cpg.fieldIdentifier.canonicalNameExact("b").inCall.code.l shouldBe List("this->b", "this->b") } } @@ -170,10 +172,55 @@ class ClassTypeTests extends C2CpgSuite(FileDefaults.CPP_EXT) { | ): Bar::Foo(a, b) {} |}""".stripMargin) val List(constructor) = cpg.typeDecl.nameExact("FooT").method.isConstructor.l - constructor.signature shouldBe "Bar.Foo(std.string,Bar.SomeClass)" - val List(p1, p2) = constructor.parameter.l - p1.typ.fullName shouldBe "std.string" - p2.typ.fullName shouldBe "Bar.SomeClass" + constructor.signature shouldBe "Bar.Foo(std.string&,Bar.SomeClass&)" + val List(thisP, p1, p2) = constructor.parameter.l + thisP.name shouldBe "this" + thisP.typeFullName shouldBe "FooT*" + thisP.index shouldBe 0 + p1.typ.fullName shouldBe "std.string&" + p1.index shouldBe 1 + p2.typ.fullName shouldBe "Bar.SomeClass&" + p2.index shouldBe 2 + } + } + + "handling C++ operator definitions" should { + "generate correct fullnames in classes" in { + val cpg = code(""" + |class Foo { + | public: + | void operator delete (void *d) { free(d); } + | bool operator == (const Foo &lhs, const Foo &rhs) { return false; } + | Foo &Foo::operator + (const Foo &lhs, const Foo &rhs) { return null; } + | Foo &Foo::operator() (const Foo &a) { return null; } + | Foo &Foo::operator[] (int index) { return null; } + |} + |Foo &Foo::operator + (const Foo &lhs, const Foo &rhs) + |""".stripMargin) + val List(del, eq, plus, apply, idx) = cpg.typeDecl.nameExact("Foo").method.l + del.name shouldBe "delete" + del.fullName shouldBe "Foo.delete:void(void*)" + eq.name shouldBe "==" + eq.fullName shouldBe "Foo.==:bool(Foo&,Foo&)" + plus.name shouldBe "+" + plus.fullName shouldBe "Foo.+:Foo&(Foo&,Foo&)" + apply.name shouldBe "()" + apply.fullName shouldBe "Foo.():Foo&(Foo&)" + idx.name shouldBe "[]" + idx.fullName shouldBe "Foo.[]:Foo&(int)" + } + + "generate correct fullnames in classes with conversions" in { + val cpg = code(""" + |class Foo { + | enum Kind { A, B, C } kind; + | public: + | operator Kind() const { return kind; } + |}; + |""".stripMargin) + val List(k) = cpg.typeDecl.nameExact("Foo").method.l + k.name shouldBe "Kind" + k.fullName shouldBe "Foo.Kind:Foo.Kind()" } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/EnumTypeTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/EnumTypeTests.scala index 27dd1b780dd2..c13688456f29 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/EnumTypeTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/EnumTypeTests.scala @@ -6,10 +6,10 @@ import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.Call import io.shiftleft.codepropertygraph.generated.nodes.FieldIdentifier import io.shiftleft.codepropertygraph.generated.nodes.Identifier -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal -class EnumTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CPP_EXT) { +class EnumTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CppExt) { "Enums" should { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/NamespaceTypeTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/NamespaceTypeTests.scala index a58938f089aa..941b45bc81f1 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/NamespaceTypeTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/NamespaceTypeTests.scala @@ -6,10 +6,10 @@ import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.Call import io.shiftleft.codepropertygraph.generated.nodes.FieldIdentifier import io.shiftleft.codepropertygraph.generated.nodes.Identifier -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal -class NamespaceTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CPP_EXT) { +class NamespaceTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CppExt) { "Namespaces" should { @@ -77,12 +77,10 @@ class NamespaceTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CPP_EXT) { | // enclosing namespaces are the global namespace, Q, and Q::V |{ return 0; } |""".stripMargin) - inside(cpg.method.nameNot("").fullName.l) { case List(m1, f1, f2, h, m2) => - m1 shouldBe "Q.V.C.m:int()" - f1 shouldBe "Q.V.f:int()" - f2 shouldBe "Q.V.f:int()" + inside(cpg.method.nameNot("").fullName.l) { case List(f, m, h) => + f shouldBe "Q.V.f:int()" + m shouldBe "Q.V.C.m:int()" h shouldBe "h:void()" - m2 shouldBe "Q.V.C.m:int()" } inside(cpg.namespaceBlock.nameNot("").l) { case List(q, v) => @@ -162,10 +160,10 @@ class NamespaceTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CPP_EXT) { namespaceX.fullName shouldBe "X" } - inside(cpg.method.internal.nameNot("").fullName.l) { case List(f, g, h) => + inside(cpg.method.internal.nameNot("").fullName.l) { case List(h, f, g) => + h shouldBe "h:void()" f shouldBe "f:void()" g shouldBe "A.g:void()" - h shouldBe "h:void()" } inside(cpg.call.filterNot(_.name == Operators.fieldAccess).l) { case List(f, g) => @@ -201,7 +199,7 @@ class NamespaceTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CPP_EXT) { a2.fullName shouldBe "A" } - inside(cpg.method.internal.nameNot("").l) { case List(f1, f2, foo, bar) => + inside(cpg.method.internal.nameNot("").l) { case List(foo, bar, f1, f2) => f1.fullName shouldBe "A.f:void(int)" f1.signature shouldBe "void(int)" f2.fullName shouldBe "A.f:void(char)" @@ -377,9 +375,7 @@ class NamespaceTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CPP_EXT) { "FinalClasses.C22", "FinalClasses.C23", "IntermediateClasses.B1", - "IntermediateClasses.B1*", - "IntermediateClasses.B2", - "IntermediateClasses.B2*" + "IntermediateClasses.B2" ) } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/StructTypeTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/StructTypeTests.scala index c6cfba93a1d4..8b4b9b76c92f 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/StructTypeTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/StructTypeTests.scala @@ -2,7 +2,7 @@ package io.joern.c2cpg.passes.types import io.joern.c2cpg.testfixtures.C2CpgSuite import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class StructTypeTests extends C2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/TemplateTypeTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/TemplateTypeTests.scala index 59720a9a3042..459407ce6e2f 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/TemplateTypeTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/TemplateTypeTests.scala @@ -2,10 +2,10 @@ package io.joern.c2cpg.passes.types import io.joern.c2cpg.parser.FileDefaults import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal -class TemplateTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CPP_EXT) { +class TemplateTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CppExt) { "Templates" should { @@ -29,7 +29,7 @@ class TemplateTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CPP_EXT) { typeDeclA.aliasTypeFullName shouldBe Option("X") typeDeclB.name shouldBe "B" typeDeclB.fullName shouldBe "B" - typeDeclB.aliasTypeFullName shouldBe Option("Y") + typeDeclB.aliasTypeFullName shouldBe Option("Y") } } @@ -72,10 +72,10 @@ class TemplateTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CPP_EXT) { |""".stripMargin) inside(cpg.method.nameNot("").internal.l) { case List(x, y) => x.name shouldBe "x" - x.fullName shouldBe "x:void(#0,#1)" + x.fullName shouldBe "x:void(T,U)" x.signature shouldBe "void(T,U)" y.name shouldBe "y" - y.fullName shouldBe "y:void(#0,#1)" + y.fullName shouldBe "y:void(T,U)" y.signature shouldBe "void(T,U)" } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/TypeNodePassTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/TypeNodePassTests.scala index 01a1ffa70919..f30e98f10f26 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/TypeNodePassTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/passes/types/TypeNodePassTests.scala @@ -1,9 +1,9 @@ package io.joern.c2cpg.passes.types import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class TypeNodePassTests extends C2CpgSuite { @@ -16,8 +16,33 @@ class TypeNodePassTests extends C2CpgSuite { |""".stripMargin) val List(foo) = cpg.typeDecl.nameExact("foo").l val List(bar) = cpg.typeDecl.nameExact("bar").l - foo.aliasTypeFullName shouldBe Option("char") - bar.aliasTypeFullName shouldBe Option("char") + foo.aliasTypeFullName shouldBe Option("char*") + bar.aliasTypeFullName shouldBe Option("char**") + } + + "be correct for reference to type" in { + val cpg = code( + """ + |typedef const char (&TwoChars)[2]; + |""".stripMargin, + "twochars.cpp" + ) + val List(bar) = cpg.typeDecl.nameExact("TwoChars").l + bar.fullName shouldBe "TwoChars" + bar.aliasTypeFullName shouldBe Option("char(&)[2]") + } + + "be correct for unknown type behind macro" in { + val cpg = code( + """ + |#define DECLARE() unknown *val = NULL + |static void foo() { + | DECLARE(); + |} + |""".stripMargin, + "unknown.cpp" + ) + cpg.local.typeFullName.l shouldBe List("unknown*") } "be correct for static decl assignment" in { @@ -126,22 +151,20 @@ class TypeNodePassTests extends C2CpgSuite { |} |""".stripMargin) inside(cpg.call("free").argument(1).l) { case List(arg) => - arg.evalType.l shouldBe List("test") + arg.evalType.l shouldBe List("test*") arg.code shouldBe "ptr" inside(arg.typ.referencedTypeDecl.l) { case List(tpe) => - tpe.fullName shouldBe "test" - tpe.name shouldBe "test" - tpe.code should startWith("struct test") + tpe.fullName shouldBe "test*" + tpe.name shouldBe "test*" } inside(cpg.local.l) { case List(ptr) => ptr.name shouldBe "ptr" ptr.typeFullName shouldBe "test*" - ptr.code shouldBe "struct test* ptr" + ptr.code shouldBe "struct test *ptr" } inside(cpg.local.typ.referencedTypeDecl.l) { case List(tpe) => - tpe.name shouldBe "test" - tpe.fullName shouldBe "test" - tpe.code should startWith("struct test") + tpe.name shouldBe "test*" + tpe.fullName shouldBe "test*" } } } @@ -169,7 +192,7 @@ class TypeNodePassTests extends C2CpgSuite { |} |""".stripMargin) inside(cpg.local.typ.referencedTypeDecl.l) { case List(tpe) => - tpe.fullName shouldBe "Foo" + tpe.fullName shouldBe "Foo*" } } @@ -190,10 +213,46 @@ class TypeNodePassTests extends C2CpgSuite { inside(cpg.method("test_func").ast.isLocal.name(badChar.name).code(".*\\*.*").l) { case List(ptr) => ptr.name shouldBe "badChar" ptr.typeFullName shouldBe "char*" - ptr.code shouldBe "char* badChar" + ptr.code shouldBe "char * badChar" } } } + + "be correct for volatile types" in { + val cpg = code(""" + |void func(void) { + | static volatile int **ipp; + | static int *ip; + | static volatile int i = 0; + | + | ipp = &ip; + | ipp = (int**) &ip; + | *ipp = &i; + | if (*ip != 0) {} + |}""".stripMargin) + cpg.identifier.nameExact("ipp").typeFullName.distinct.l shouldBe List("volatile int**") + cpg.identifier.nameExact("ip").typeFullName.distinct.l shouldBe List("int*") + cpg.identifier.nameExact("i").typeFullName.distinct.l shouldBe List("volatile int") + cpg.local.nameExact("ipp").typeFullName.l shouldBe List("volatile int**") + cpg.local.nameExact("ip").typeFullName.l shouldBe List("int*") + cpg.local.nameExact("i").typeFullName.l shouldBe List("volatile int") + } + + "be correct for referenced types from locals" in { + val cpg = code(""" + |struct flex { + | int a; + | char b[]; + |}; + |void foo() { + | struct flex *ptr = malloc(sizeof(struct flex)); + | struct flex value = {0}; + |}""".stripMargin) + val List(value) = cpg.typeDecl.fullNameExact("flex").referencingType.fullNameExact("flex").localOfType.l + value.name shouldBe "value" + value.typeFullName shouldBe "flex" + value.code shouldBe "struct flex value" + } } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/AstQueryTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/AstQueryTests.scala index b77feba015d4..f3b9675ec390 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/AstQueryTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/AstQueryTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.querying import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class AstQueryTests extends C2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/CfgQueryTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/CfgQueryTests.scala index f66694c2ceb2..8c31b9ce4a69 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/CfgQueryTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/CfgQueryTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.querying import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CfgQueryTests extends C2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/DdgCfgQueryTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/DdgCfgQueryTests.scala index befa06bf7d00..f2e558ed8dd9 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/DdgCfgQueryTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/DdgCfgQueryTests.scala @@ -2,8 +2,8 @@ package io.joern.c2cpg.querying import io.joern.c2cpg.testfixtures.DataFlowCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.nodes -import io.joern.dataflowengineoss.language._ -import io.shiftleft.semanticcpg.language._ +import io.joern.dataflowengineoss.language.* +import io.shiftleft.semanticcpg.language.* class DdgCfgQueryTests extends DataFlowCodeToCpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/LocalQueryTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/LocalQueryTests.scala index d858744b122e..e03416bdd4ae 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/LocalQueryTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/LocalQueryTests.scala @@ -1,85 +1,126 @@ package io.joern.c2cpg.querying import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* /** Language primitives for navigating local variables */ class LocalQueryTests extends C2CpgSuite { - private val cpg = code(""" - | struct node { - | int value; - | struct node *next; - | }; - | - | void free_list(struct node *head) { - | struct node *q; - | for (struct node *p = head; p != NULL; p = q) { - | q = p->next; - | free(p); - | } - | } - | - | int flow(int p0) { - | int a = p0; - | int b = a; - | int c = 0x31; - | int z = b + c; - | z++; - | int x = z; - | return x; - | } - | - | void test() { - | static int a, b, c; - | wchar_t *foo; - | int d[10], e = 1; - | } - | """.stripMargin) + "local query example 1" should { + "allow to query for the locals" in { + val cpg = code( + """ + |void foo() { + | static const Foo::Bar bar{}; + | static extern std::vector vec; + |} + |""".stripMargin, + "test.cpp" + ) + val List(barLocal) = cpg.method.name("foo").local.nameExact("bar").l + barLocal.typeFullName shouldBe "Foo.Bar" + barLocal.code shouldBe "static const Foo::Bar bar" - "should allow to query for all locals" in { - cpg.local.name.toSetMutable shouldBe Set("a", "b", "c", "e", "d", "z", "x", "q", "p", "foo") + val List(vecLocal) = cpg.method.name("foo").local.nameExact("vec").l + vecLocal.typeFullName shouldBe "std.vector" + vecLocal.code shouldBe "static extern std::vector vec" + } } - "should prove correct (name, type) pairs for locals" in { - inside(cpg.method.name("free_list").local.l) { case List(q, p) => - q.name shouldBe "q" - q.typeFullName shouldBe "node*" - q.code shouldBe "struct node* q" - p.name shouldBe "p" - p.typeFullName shouldBe "node*" - p.code shouldBe "struct node* p" + "local query example 2" should { + "allow to query for the local" in { + val cpg = code( + """ + |class Foo { + | static Foo* foo() { + | static Foo bar; + | return &bar; + | } + |} + |""".stripMargin, + "test.cpp" + ) + val List(barLocal) = cpg.method.name("foo").local.nameExact("bar").l + barLocal.typeFullName shouldBe "Foo" + barLocal.code shouldBe "static Foo bar" } } - "should prove correct (name, type, code) pairs for locals" in { - inside(cpg.method.name("test").local.l) { case List(a, b, c, foo, d, e) => - a.name shouldBe "a" - a.typeFullName shouldBe "int" - a.code shouldBe "static int a" - b.name shouldBe "b" - b.typeFullName shouldBe "int" - b.code shouldBe "static int b" - c.name shouldBe "c" - c.typeFullName shouldBe "int" - c.code shouldBe "static int c" - foo.name shouldBe "foo" - foo.typeFullName shouldBe "wchar_t*" - foo.code shouldBe "wchar_t* foo" - d.name shouldBe "d" - d.typeFullName shouldBe "int[10]" - d.code shouldBe "int[10] d" - e.name shouldBe "e" - e.typeFullName shouldBe "int" - e.code shouldBe "int e" + "local query example 3" should { + val cpg = code(""" + | struct node { + | int value; + | struct node *next; + | }; + | + | void free_list(struct node *head) { + | struct node *q; + | for (struct node *p = head; p != NULL; p = q) { + | q = p->next; + | free(p); + | } + | } + | + | int flow(int p0) { + | int a = p0; + | int b = a; + | int c = 0x31; + | int z = b + c; + | z++; + | int x = z; + | return x; + | } + | + | void test() { + | static int a, *b, c[1]; + | wchar_t *foo; + | int d[10], e = 1; + | } + | """.stripMargin) + + "should allow to query for all locals" in { + cpg.local.name.toSetMutable shouldBe Set("a", "b", "c", "e", "d", "z", "x", "q", "p", "foo") } - } - "should allow finding filenames by local regex" in { - val filename = cpg.local.name("a*").file.name.headOption - filename should not be empty - filename.head.endsWith(".c") shouldBe true - } + "should prove correct (name, type) pairs for locals" in { + inside(cpg.method.name("free_list").local.l) { case List(q, p) => + q.name shouldBe "q" + q.typeFullName shouldBe "node*" + q.code shouldBe "struct node *q" + p.name shouldBe "p" + p.typeFullName shouldBe "node*" + p.code shouldBe "struct node *p" + } + } + "should prove correct (name, type, code) pairs for locals" in { + inside(cpg.method.name("test").local.l) { case List(a, b, c, foo, d, e) => + a.name shouldBe "a" + a.typeFullName shouldBe "int" + a.code shouldBe "static int a" + b.name shouldBe "b" + b.typeFullName shouldBe "int*" + b.code shouldBe "static int *b" + c.name shouldBe "c" + c.typeFullName shouldBe "int[1]" + c.code shouldBe "static int c[1]" + foo.name shouldBe "foo" + foo.typeFullName shouldBe "wchar_t*" + foo.code shouldBe "wchar_t *foo" + d.name shouldBe "d" + d.typeFullName shouldBe "int[10]" + d.code shouldBe "int d[10]" + e.name shouldBe "e" + e.typeFullName shouldBe "int" + e.code shouldBe "int e" + } + } + + "should allow finding filenames by local regex" in { + val filename = cpg.local.name("a*").file.name.headOption + filename should not be empty + filename.head.endsWith(".c") shouldBe true + } + } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/LocationQueryTests.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/LocationQueryTests.scala index a4bc22e264c4..7ac71e73f9fe 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/LocationQueryTests.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/LocationQueryTests.scala @@ -1,7 +1,7 @@ package io.joern.c2cpg.querying import io.joern.c2cpg.testfixtures.C2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LocationQueryTests extends C2CpgSuite { diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/AstC2CpgFrontend.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/AstC2CpgFrontend.scala index bc837309b80f..71d522a20ae0 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/AstC2CpgFrontend.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/AstC2CpgFrontend.scala @@ -3,7 +3,9 @@ package io.joern.c2cpg.testfixtures import better.files.File import io.joern.c2cpg.Config import io.joern.c2cpg.passes.AstCreationPass +import io.joern.c2cpg.passes.FunctionDeclNodePass import io.joern.x2cpg.testfixtures.LanguageFrontend +import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.X2Cpg.newEmptyCpg import io.shiftleft.codepropertygraph.generated.Cpg @@ -19,6 +21,8 @@ trait AstC2CpgFrontend extends LanguageFrontend { .withOutputPath(pathAsString) val astCreationPass = new AstCreationPass(cpg, config) astCreationPass.createAndApply() + new FunctionDeclNodePass(cpg, astCreationPass.unhandledMethodDeclarations())(ValidationMode.Enabled) + .createAndApply() cpg } } diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/AstC2CpgSuite.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/AstC2CpgSuite.scala index 94880aa12b0c..be2dcee2127d 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/AstC2CpgSuite.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/AstC2CpgSuite.scala @@ -3,4 +3,4 @@ package io.joern.c2cpg.testfixtures import io.joern.c2cpg.parser.FileDefaults import io.joern.x2cpg.testfixtures.Code2CpgFixture -class AstC2CpgSuite(fileSuffix: String = FileDefaults.C_EXT) extends Code2CpgFixture(() => new CAstTestCpg(fileSuffix)) +class AstC2CpgSuite(fileSuffix: String = FileDefaults.CExt) extends Code2CpgFixture(() => new CAstTestCpg(fileSuffix)) diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/C2CpgSuite.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/C2CpgSuite.scala index 95cb9238cfcc..8f99314d9aca 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/C2CpgSuite.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/C2CpgSuite.scala @@ -1,17 +1,18 @@ package io.joern.c2cpg.testfixtures import io.joern.c2cpg.parser.FileDefaults -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic +import io.joern.dataflowengineoss.DefaultSemantics +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.joern.x2cpg.testfixtures.Code2CpgFixture class C2CpgSuite( - fileSuffix: String = FileDefaults.C_EXT, + fileSuffix: String = FileDefaults.CExt, withOssDataflow: Boolean = false, - extraFlows: List[FlowSemantic] = List.empty, + semantics: Semantics = DefaultSemantics(), withPostProcessing: Boolean = false ) extends Code2CpgFixture(() => new CDefaultTestCpg(fileSuffix) .withOssDataflow(withOssDataflow) - .withExtraFlows(extraFlows) + .withSemantics(semantics) .withPostProcessingPasses(withPostProcessing) ) diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/CCfgTestCpg.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/CCfgTestCpg.scala index e7fb53ce6f68..95de90ca209d 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/CCfgTestCpg.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/CCfgTestCpg.scala @@ -3,4 +3,4 @@ package io.joern.c2cpg.testfixtures import io.joern.c2cpg.parser.FileDefaults import io.joern.x2cpg.testfixtures.CfgTestCpg -class CCfgTestCpg(override val fileSuffix: String = FileDefaults.C_EXT) extends CfgTestCpg with C2CpgFrontend {} +class CCfgTestCpg(override val fileSuffix: String = FileDefaults.CExt) extends CfgTestCpg with C2CpgFrontend {} diff --git a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/DataFlowCodeToCpgSuite.scala b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/DataFlowCodeToCpgSuite.scala index 16211eb85508..fe37bafd99cc 100644 --- a/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/DataFlowCodeToCpgSuite.scala +++ b/joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/testfixtures/DataFlowCodeToCpgSuite.scala @@ -11,7 +11,7 @@ import io.joern.x2cpg.testfixtures.TestCpg import io.shiftleft.semanticcpg.layers.LayerCreatorContext class DataFlowTestCpg extends TestCpg with C2CpgFrontend { - override val fileSuffix: String = FileDefaults.C_EXT + override val fileSuffix: String = FileDefaults.CExt override def applyPasses(): Unit = { X2Cpg.applyDefaultOverlays(this) @@ -27,7 +27,7 @@ class DataFlowCodeToCpgSuite extends Code2CpgFixture(() => new DataFlowTestCpg() protected implicit val context: EngineContext = EngineContext() protected def flowToResultPairs(path: Path): List[(String, Integer)] = - path.resultPairs().collect { case (firstElement: String, secondElement: Option[Integer]) => + path.resultPairs().collect { case (firstElement: String, secondElement) => (firstElement, secondElement.getOrElse(-1)) } } diff --git a/joern-cli/frontends/csharpsrc2cpg/build.sbt b/joern-cli/frontends/csharpsrc2cpg/build.sbt index a112ff51dcac..3be353350ffe 100644 --- a/joern-cli/frontends/csharpsrc2cpg/build.sbt +++ b/joern-cli/frontends/csharpsrc2cpg/build.sbt @@ -76,16 +76,15 @@ astGenDlTask := { astGenDir.mkdirs() astGenBinaryNames.value.foreach { fileName => - DownloadHelper.ensureIsAvailable(s"${astGenDlUrl.value}$fileName", astGenDir / fileName) + val file = astGenDir / fileName + DownloadHelper.ensureIsAvailable(s"${astGenDlUrl.value}$fileName", file) + // permissions are lost during the download; need to set them manually + file.setExecutable(true, false) } val distDir = (Universal / stagingDirectory).value / "bin" / "astgen" distDir.mkdirs() - IO.copyDirectory(astGenDir, distDir) - - // permissions are lost during the download; need to set them manually - astGenDir.listFiles().foreach(_.setExecutable(true, false)) - distDir.listFiles().foreach(_.setExecutable(true, false)) + IO.copyDirectory(astGenDir, distDir, preserveExecutable = true) } Compile / compile := ((Compile / compile) dependsOn astGenDlTask).value @@ -100,3 +99,7 @@ stage := Def Universal / packageName := name.value Universal / topLevelDirectory := None + +/** write the astgen version to the manifest for downstream usage */ +Compile / packageBin / packageOptions += + Package.ManifestAttributes(new java.util.jar.Attributes.Name("DotNet-AstGen-Version") -> astGenVersion.value) diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/resources/application.conf b/joern-cli/frontends/csharpsrc2cpg/src/main/resources/application.conf index f9f979f2f3be..2a1ed9f7b1f0 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/resources/application.conf +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/resources/application.conf @@ -1,3 +1,3 @@ csharpsrc2cpg { - dotnetastgen_version: "0.34.0" + dotnetastgen_version: "0.39.0" } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/CSharpSrc2Cpg.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/CSharpSrc2Cpg.scala index 096ccca8c4d4..b954b6936beb 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/CSharpSrc2Cpg.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/CSharpSrc2Cpg.scala @@ -1,11 +1,17 @@ package io.joern.csharpsrc2cpg import better.files.File +import io.joern.csharpsrc2cpg.CSharpSrc2Cpg.findBuildFiles import io.joern.csharpsrc2cpg.astcreation.AstCreator import io.joern.csharpsrc2cpg.datastructures.CSharpProgramSummary import io.joern.csharpsrc2cpg.parser.DotNetJsonParser import io.joern.csharpsrc2cpg.passes.{AstCreationPass, DependencyPass} -import io.joern.csharpsrc2cpg.utils.{DependencyDownloader, DotNetAstGenRunner} +import io.joern.csharpsrc2cpg.utils.{ + DependencyDownloader, + DotNetAstGenRunner, + ImplicitUsingsCollector, + ProgramSummaryCreator +} import io.joern.x2cpg.X2Cpg.withNewEmptyCpg import io.joern.x2cpg.astgen.AstGenRunner.AstGenRunnerResult import io.joern.x2cpg.astgen.ParserResult @@ -35,33 +41,21 @@ class CSharpSrc2Cpg extends X2CpgFrontend[Config] { File.usingTemporaryDirectory("csharpsrc2cpgOut") { tmpDir => val astGenResult = new DotNetAstGenRunner(config).execute(tmpDir) val astCreators = CSharpSrc2Cpg.processAstGenRunnerResults(astGenResult.parsedFiles, config) - // Pre-parse the AST creators for high level structures - val internalProgramSummary = ConcurrentTaskUtil - .runUsingThreadPool(astCreators.map(x => () => x.summarize()).iterator) - .flatMap { - case Failure(exception) => logger.warn(s"Unable to pre-parse C# file, skipping - ", exception); None - case Success(summary) => Option(summary) - } - .foldLeft(CSharpProgramSummary(imports = CSharpProgramSummary.initialImports))(_ ++= _) - - val builtinSummary = CSharpProgramSummary( - mutable.Map - .fromSpecific(CSharpProgramSummary.BuiltinTypes.view.filterKeys(internalProgramSummary.imports(_))) - .result() - ) - - val internalAndBuiltinSummary = internalProgramSummary ++= builtinSummary + val buildFiles = findBuildFiles(config) + val localSummary = ProgramSummaryCreator + .from(astCreators, config) + .addGlobalImports(ImplicitUsingsCollector.collect(buildFiles).toSet) val hash = HashUtil.sha256(astCreators.map(_.parserResult).map(x => Paths.get(x.fullPath))) new MetaDataPass(cpg, Languages.CSHARPSRC, config.inputPath, Option(hash)).createAndApply() val packageIds = mutable.HashSet.empty[String] - new DependencyPass(cpg, buildFiles(config), packageIds.add).createAndApply() + new DependencyPass(cpg, buildFiles, packageIds.add).createAndApply() // If "download dependencies" is enabled, then fetch dependencies and resolve their symbols for additional types val programSummary = if (config.downloadDependencies) { - DependencyDownloader(cpg, config, internalAndBuiltinSummary, packageIds.toSet).download() + DependencyDownloader(cpg, config, localSummary, packageIds.toSet).download() } else { - internalAndBuiltinSummary + localSummary } new AstCreationPass(cpg, astCreators.map(_.withSummary(programSummary)), report).createAndApply() TypeNodePass.withTypesFromCpg(cpg).createAndApply() @@ -70,16 +64,6 @@ class CSharpSrc2Cpg extends X2CpgFrontend[Config] { } } - private def buildFiles(config: Config): List[String] = { - SourceFiles.determine( - config.inputPath, - Set(".csproj"), - Option(config.defaultIgnoredFilesRegex), - Option(config.ignoredFilesRegex), - Option(config.ignoredFiles) - ) - } - } object CSharpSrc2Cpg { @@ -108,6 +92,16 @@ object CSharpSrc2Cpg { ) } + def findBuildFiles(config: Config): List[String] = { + SourceFiles.determine( + config.inputPath, + Set(".csproj"), + Option(config.defaultIgnoredFilesRegex), + Option(config.ignoredFilesRegex), + Option(config.ignoredFiles) + ) + } + /** Addresses behaviour in Windows where a user-specific temp folder is used: parserResult.fullPath = * C:\Users\runneradmin\AppData\Local\Temp\... config.inputPath = C:\Users\RUNNER~1\AppData\Local\Temp\... * diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/Constants.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/Constants.scala index 4fabe8274b0b..1a59f4782a8a 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/Constants.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/Constants.scala @@ -1,8 +1,9 @@ package io.joern.csharpsrc2cpg object Constants { - val This: String = "this" - val Global: String = "global" + val This: String = "this" + val Global: String = "global" + val TopLevelMainMethodName: String = "
$" } object CSharpOperators { diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/Main.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/Main.scala index 2333179302cb..73c756627787 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/Main.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/Main.scala @@ -5,13 +5,17 @@ import io.joern.x2cpg.astgen.AstGenConfig import io.joern.x2cpg.passes.frontend.{TypeRecoveryParserConfig, XTypeRecovery, XTypeRecoveryConfig} import io.joern.x2cpg.utils.Environment import io.joern.x2cpg.{DependencyDownloadConfig, X2CpgConfig, X2CpgMain} +import io.joern.x2cpg.utils.server.FrontendHTTPServer import org.slf4j.LoggerFactory import scopt.OParser import java.nio.file.Paths -final case class Config(downloadDependencies: Boolean = false) - extends X2CpgConfig[Config] +final case class Config( + downloadDependencies: Boolean = false, + useBuiltinSummaries: Boolean = true, + externalSummaryPaths: Set[String] = Set.empty +) extends X2CpgConfig[Config] with DependencyDownloadConfig[Config] with TypeRecoveryParserConfig[Config] with AstGenConfig[Config] { @@ -23,6 +27,14 @@ final case class Config(downloadDependencies: Boolean = false) copy(downloadDependencies = value).withInheritedFields(this) } + def withUseBuiltinSummaries(value: Boolean): Config = { + copy(useBuiltinSummaries = value).withInheritedFields(this) + } + + def withExternalSummaryPaths(paths: Set[String]): Config = { + copy(externalSummaryPaths = paths).withInheritedFields(this) + } + } object Frontend { @@ -34,22 +46,33 @@ object Frontend { OParser.sequence( programName("csharpsrc2cpg"), DependencyDownloadConfig.parserOptions, - XTypeRecoveryConfig.parserOptionsForParserConfig + XTypeRecoveryConfig.parserOptionsForParserConfig, + opt[Unit]("disable-builtin-summaries") + .text("do not use the built-in type summaries") + .action((_, c) => c.withUseBuiltinSummaries(false)), + opt[Seq[String]]("external-summary-paths") + .text("where to look for external type summaries produced by DotNetAstGen (comma-separated list of paths)") + .action((paths, c) => c.withExternalSummaryPaths(c.externalSummaryPaths ++ paths)) ) } } -object Main extends X2CpgMain(cmdLineParser, new CSharpSrc2Cpg()) { +object Main extends X2CpgMain(cmdLineParser, new CSharpSrc2Cpg()) with FrontendHTTPServer[Config, CSharpSrc2Cpg] { private val logger = LoggerFactory.getLogger(getClass) + override protected def newDefaultConfig(): Config = Config() + def run(config: Config, csharpsrc2cpg: CSharpSrc2Cpg): Unit = { - val absPath = Paths.get(config.inputPath).toAbsolutePath.toString - if (Environment.pathExists(absPath)) { - csharpsrc2cpg.run(config.withInputPath(absPath)) - } else { - logger.warn(s"Given path '$absPath' does not exist, skipping") + if (config.serverMode) { startup() } + else { + val absPath = Paths.get(config.inputPath).toAbsolutePath.toString + if (Environment.pathExists(absPath)) { + csharpsrc2cpg.run(config.withInputPath(absPath)) + } else { + logger.warn(s"Given path '$absPath' does not exist, skipping") + } } } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreator.scala index 90291c57fcb6..90b5c6ddda5c 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreator.scala @@ -1,16 +1,24 @@ package io.joern.csharpsrc2cpg.astcreation -import io.joern.csharpsrc2cpg.{CSharpDefines, Constants} -import io.joern.csharpsrc2cpg.datastructures.{CSharpProgramSummary, CSharpScope} +import io.joern.csharpsrc2cpg.Constants +import io.joern.csharpsrc2cpg.datastructures.{CSharpProgramSummary, CSharpScope, MethodScope, TypeScope} import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.* import io.joern.csharpsrc2cpg.parser.{DotNetNodeInfo, ParserKeys} +import io.joern.csharpsrc2cpg.utils.Utils.* import io.joern.x2cpg.astgen.{AstGenNodeBuilder, ParserResult} +import io.joern.x2cpg.utils.NodeBuilders.newModifierNode import io.joern.x2cpg.{Ast, AstCreatorBase, ValidationMode} -import io.shiftleft.codepropertygraph.generated.NodeTypes -import io.shiftleft.codepropertygraph.generated.nodes.{NewFile, NewTypeDecl} -import io.shiftleft.passes.IntervalKeyPool +import io.shiftleft.codepropertygraph.generated.{DiffGraphBuilder, ModifierTypes, NodeTypes} +import io.shiftleft.codepropertygraph.generated.nodes.{ + NewBlock, + NewFile, + NewMethod, + NewMethodParameterIn, + NewMethodReturn, + NewModifier, + NewTypeDecl +} import org.slf4j.{Logger, LoggerFactory} -import overflowdb.BatchedUpdate.DiffGraphBuilder import ujson.Value import java.math.BigInteger @@ -49,9 +57,78 @@ class AstCreator( } protected def astForCompilationUnit(cu: DotNetNodeInfo): Seq[Ast] = { - val imports = cu.json(ParserKeys.Usings).arr.flatMap(astForNode).toSeq - val memberAsts = astForMembers(cu.json(ParserKeys.Members).arr.map(createDotNetNodeInfo).toSeq) - imports ++ memberAsts + val importAsts = cu.json(ParserKeys.Usings).arr.flatMap(astForNode).toSeq + val members = cu.json(ParserKeys.Members).arr.map(createDotNetNodeInfo).toSeq + val (globalStatements, nonGlobalStatements) = members.partition(_.node == GlobalStatement) + val nonGlobalStatementAsts = nonGlobalStatements.flatMap(astForNode) + + // If there are global statements, we should treat this file as an entry-point. + // Roslyn implicitly wraps these statements inside the following block: + // ``` + // internal class Program { + // private static void
$(string[] args) { + // + // } + // } + // ``` + // Note: there can only be one such file in a given project, but we are currently + // not checking this. + if (globalStatements.nonEmpty) { + importAsts ++ astForTopLevelStatements(globalStatements) ++ nonGlobalStatementAsts + } else { + importAsts ++ nonGlobalStatementAsts + } + } + + private def astForTopLevelStatements(topLevelStmts: Seq[DotNetNodeInfo]): Seq[Ast] = { + val className = composeTopLevelClassName(relativeFileName) + val classFullName = className + val mainName = Constants.TopLevelMainMethodName + val mainParameters = List(("args", "System.String[]")) + val voidType = BuiltinTypes.DotNetTypeMap(BuiltinTypes.Void) + val mainSignature = composeMethodLikeSignature(voidType, mainParameters.map(_._2)) + val mainFullName = composeMethodFullName(classFullName, mainName, mainSignature) + + val classNode = NewTypeDecl() + .name(className) + .fullName(classFullName) + .filename(relativeFileName) + + val classModifiers = newModifierNode(ModifierTypes.INTERNAL) :: Nil + val methodNode = NewMethod() + .name(mainName) + .fullName(mainFullName) + .filename(relativeFileName) + .signature(mainSignature) + + val methodModifiers = newModifierNode(ModifierTypes.STATIC) :: newModifierNode(ModifierTypes.PRIVATE) :: Nil + val argsParameters = mainParameters.map(x => NewMethodParameterIn().name(x._1).typeFullName(x._2)) + val methodBlock = NewBlock().typeFullName(voidType) + val methodReturn = NewMethodReturn().typeFullName(methodBlock.typeFullName) + + val topLevelStmtAsts = { + scope.pushNewScope(TypeScope(classFullName)) + scope.pushNewScope(MethodScope(mainFullName)) + argsParameters.foreach(x => scope.addToScope(x.name, x)) + + val asts = topLevelStmts.flatMap(astForNode) + + scope.popScope() + scope.popScope() + asts + } + + val methodAst = Ast(methodNode) + .withChildren(methodModifiers.map(Ast(_))) + .withChild(Ast(argsParameters)) + .withChild(Ast(methodBlock).withChildren(topLevelStmtAsts)) + .withChild(Ast(methodReturn)) + + val classAst = Ast(classNode) + .withChildren(classModifiers.map(Ast(_))) + .withChild(methodAst) + + Seq(classAst) } protected def astForMembers(members: Seq[DotNetNodeInfo]): Seq[Ast] = { diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala index a070f8c8b539..f56f7645895e 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala @@ -2,15 +2,17 @@ package io.joern.csharpsrc2cpg.astcreation import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.* import io.joern.csharpsrc2cpg.parser.{DotNetJsonAst, DotNetNodeInfo, ParserKeys} +import io.joern.csharpsrc2cpg.utils.Utils.{withoutSignature} import io.joern.csharpsrc2cpg.{CSharpDefines, Constants, astcreation} +import io.joern.x2cpg.utils.IntervalKeyPool import io.joern.x2cpg.{Ast, Defines, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, PropertyNames} -import io.shiftleft.passes.IntervalKeyPool +import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames} import ujson.Value import scala.annotation.tailrec import scala.util.{Failure, Success, Try} + trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: AstCreator => private val anonymousTypeKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) @@ -29,15 +31,6 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } } - def createCallNodeForOperator( - node: DotNetNodeInfo, - operatorMethod: String, - signature: Option[String] = None, - typeFullName: Option[String] = None - ): NewCall = { - callNode(node, node.code, operatorMethod, operatorMethod, DispatchTypes.STATIC_DISPATCH, signature, typeFullName) - } - protected def notHandledYet(node: DotNetNodeInfo): Seq[Ast] = { val text = s"""Node type '${node.node}' not handled yet! @@ -52,7 +45,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As protected def astFullName(node: DotNetNodeInfo): String = { scope.surroundingScopeFullName match - case Some(fullName) => s"$fullName.${nameFromNode(node)}" + case Some(fullName) => s"${withoutSignature(fullName)}.${nameFromNode(node)}" case _ => nameFromNode(node) } @@ -83,7 +76,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As case x: NewMethodParameterIn => identifierNode(dotNetNode.orNull, x.name, x.code, x.typeFullName, x.dynamicTypeHintFullName) case x => - logger.warn(s"Unhandled declaration type '${x.label()}' for ${x.name}") + logger.warn(s"Unhandled declaration type '${x.label}' for ${x.name}") identifierNode(dotNetNode.orNull, x.name, x.name, Defines.Any) } @@ -98,6 +91,36 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As Operators.lessEqualsThan -> BuiltinTypes.DotNetTypeMap(BuiltinTypes.Bool) ) + protected val binaryOperatorsMap: Map[String, String] = Map( + "+" -> Operators.addition, + "-" -> Operators.subtraction, + "*" -> Operators.multiplication, + "/" -> Operators.division, + "%" -> Operators.modulo, + "==" -> Operators.equals, + "!=" -> Operators.notEquals, + "&&" -> Operators.logicalAnd, + "||" -> Operators.logicalOr, + "=" -> Operators.assignment, + "+=" -> Operators.assignmentPlus, + "-=" -> Operators.assignmentMinus, + "*=" -> Operators.assignmentMultiplication, + "/=" -> Operators.assignmentDivision, + "%=" -> Operators.assignmentModulo, + "&=" -> Operators.assignmentAnd, + "|=" -> Operators.assignmentOr, + "^=" -> Operators.assignmentXor, + ">>=" -> Operators.assignmentLogicalShiftRight, + "<<=" -> Operators.assignmentShiftLeft, + ">" -> Operators.greaterThan, + "<" -> Operators.lessThan, + ">=" -> Operators.greaterEqualsThan, + "<=" -> Operators.lessEqualsThan, + "|" -> Operators.or, + "&" -> Operators.and, + "^" -> Operators.xor + ) + protected def nodeTypeFullName(node: DotNetNodeInfo): String = { node.node match { case NumericLiteralExpression if node.code.matches("^\\d+$") => // e.g. 200 @@ -133,12 +156,22 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As case NullableType => val elementTypeNode = createDotNetNodeInfo(node.json(ParserKeys.ElementType)) nodeTypeFullName(elementTypeNode) + case QualifiedName => + val left = nameFromNode(createDotNetNodeInfo(node.json(ParserKeys.Left))) + val right = nameFromNode(createDotNetNodeInfo(node.json(ParserKeys.Right))) + s"$left.$right" case IdentifierName => val typeString = nameFromNode(node) scope .tryResolveTypeReference(typeString) .map(_.name) .orElse(BuiltinTypes.DotNetTypeMap.get(typeString)) + .orElse(scope.findFieldInScope(typeString).map(_.typeFullName)) + .orElse(scope.lookupVariable(typeString).flatMap { + case x: NewLocal => Some(x.typeFullName) + case x: NewMethodParameterIn => Some(x.typeFullName) + case _ => None + }) .getOrElse(typeString) case Attribute => val typeString = s"${nameFromNode(node)}Attribute" @@ -196,7 +229,7 @@ object AstCreatorHelper { case SimpleMemberAccessExpression | MemberBindingExpression | SuppressNullableWarningExpression | Attribute => nameFromIdentifier(createDotNetNodeInfo(node.json(ParserKeys.Name))) case ObjectCreationExpression | CastExpression => nameFromNode(createDotNetNodeInfo(node.json(ParserKeys.Type))) - case ThisExpression => "this" + case ThisExpression => Constants.This case _ => "" } @@ -279,6 +312,6 @@ object BuiltinTypes { String -> "System.String", Dynamic -> "System.Object", Null -> Null, - Void -> Void + Void -> "System.Void" ) } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForDeclarationsCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForDeclarationsCreator.scala index 9eca89334eba..c9129e03c39a 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForDeclarationsCreator.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForDeclarationsCreator.scala @@ -1,11 +1,18 @@ package io.joern.csharpsrc2cpg.astcreation -import io.joern.csharpsrc2cpg.CSharpModifiers +import io.joern.csharpsrc2cpg.{CSharpModifiers, Constants} +import io.joern.csharpsrc2cpg.astcreation.AstParseLevel.FULL_AST import io.joern.csharpsrc2cpg.astcreation.BuiltinTypes.DotNetTypeMap import io.joern.csharpsrc2cpg.datastructures.* import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.* import io.joern.csharpsrc2cpg.parser.{DotNetNodeInfo, ParserKeys} -import io.joern.csharpsrc2cpg.utils.Utils.{composeMethodFullName, composeMethodLikeSignature} +import io.joern.csharpsrc2cpg.utils.Utils.{ + composeGetterName, + composeMethodFullName, + composeMethodLikeSignature, + composeSetterName, + withoutSignature +} import io.joern.x2cpg.utils.NodeBuilders.{newMethodReturnNode, newModifierNode} import io.joern.x2cpg.{Ast, Defines, ValidationMode} import io.shiftleft.codepropertygraph.generated.* @@ -70,11 +77,9 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { typeDeclNode(classDecl, name, fullName, relativeFileName, code(classDecl), inherits = inheritsFromTypeFullName) scope.pushNewScope(TypeScope(fullName)) val modifiers = astForModifiers(classDecl) - val members = astForMembers(classDecl.json(ParserKeys.Members).arr.map(createDotNetNodeInfo).toSeq) - - // TODO: Check if any explicit constructor / static constructor decls exists, - // if it doesn't, need to add in default constructor and static constructor and - // pull all field initializations into them. + val members = astForMembers(classDecl.json(ParserKeys.Members).arr.map(createDotNetNodeInfo).toSeq) + ++ addConstructorWithFieldInitializationsIfNeeded(fullName) + ++ addStaticConstructorWithFieldInitializationsIfNeeded(fullName) scope.popScope() val typeDeclAst = Ast(typeDecl) @@ -84,6 +89,86 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { Seq(typeDeclAst) } + private def addConstructorWithFieldInitializationsIfNeeded(typeDeclFullName: String): Seq[Ast] = { + val dynamicFields = scope.getFieldsInScope.filter(f => !f.isStatic && f.isInitialized) + val hasExplicitCtor = + scope.tryResolveTypeReference(typeDeclFullName).exists(_.methods.exists(_.name == Defines.ConstructorMethodName)) + // We should only create the constructor when we are the FULL_AST parseLevel. Otherwise, hasExplicitCtor will + // not be accurate. + val shouldBuildCtor = dynamicFields.nonEmpty && !hasExplicitCtor && parseLevel == FULL_AST + + if (shouldBuildCtor) { + val methodReturn = newMethodReturnNode(DotNetTypeMap(BuiltinTypes.Void), None, None, None) + val signature = composeMethodLikeSignature(methodReturn.typeFullName) + val modifiers = Seq(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.INTERNAL)) + val name = Defines.ConstructorMethodName + val fullName = composeMethodFullName(typeDeclFullName, name, signature) + + val body = { + scope.pushNewScope(MethodScope(fullName)) + val fieldInitAssignmentAsts = astVariableDeclarationForInitializedFields(dynamicFields) + scope.popScope() + Ast(NewBlock().typeFullName(Defines.Any)).withChildren(fieldInitAssignmentAsts) + } + + val methodNode_ = NewMethod() + .name(name) + .fullName(fullName) + .signature(signature) + .filename(relativeFileName) + + val parameterNodes = Seq( + NewMethodParameterIn() + .name(Constants.This) + .code(Constants.This) + .typeFullName(typeDeclFullName) + .evaluationStrategy(EvaluationStrategies.BY_SHARING.name) + .isVariadic(false) + .index(0) + ) + + methodAst(methodNode_, parameterNodes.map(Ast(_)), body, methodReturn, modifiers) :: Nil + } else { + Seq.empty + } + } + + private def addStaticConstructorWithFieldInitializationsIfNeeded(typeDeclFullname: String): Seq[Ast] = { + val staticFields = scope.getFieldsInScope.filter(f => f.isStatic && f.isInitialized) + val hasExplicitCtor = + scope.tryResolveTypeReference(typeDeclFullname).exists(_.methods.exists(_.name == Defines.StaticInitMethodName)) + val shouldBuildCtor = staticFields.nonEmpty && !hasExplicitCtor && parseLevel == FULL_AST + + if (shouldBuildCtor) { + val methodReturn = newMethodReturnNode(DotNetTypeMap(BuiltinTypes.Void), None, None, None) + val signature = composeMethodLikeSignature(methodReturn.typeFullName) + val modifiers = Seq( + newModifierNode(ModifierTypes.CONSTRUCTOR), + newModifierNode(ModifierTypes.INTERNAL), + newModifierNode(ModifierTypes.STATIC) + ) + val name = Defines.StaticInitMethodName + val fullName = composeMethodFullName(typeDeclFullname, name, signature) + + val body = { + scope.pushNewScope(MethodScope(fullName)) + val fieldInitAssignmentAsts = astVariableDeclarationForInitializedFields(staticFields) + scope.popScope() + Ast(NewBlock().typeFullName(Defines.Any)).withChildren(fieldInitAssignmentAsts) + } + + val methodNode_ = NewMethod() + .name(name) + .fullName(fullName) + .signature(signature) + .filename(relativeFileName) + + methodAst(methodNode_, Nil, body, methodReturn, modifiers) :: Nil + } else { + Nil + } + } + protected def astForRecordDeclaration(recordDecl: DotNetNodeInfo): Seq[Ast] = { val name = nameFromNode(recordDecl) val fullName = astFullName(recordDecl) @@ -166,13 +251,9 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { } protected def astForFieldDeclaration(fieldDecl: DotNetNodeInfo): Seq[Ast] = { - val isStatic = fieldDecl - .json(ParserKeys.Modifiers) - .arr - .flatMap(astForModifier) - .flatMap(_.root) - .collectFirst { case x: NewModifier => x.modifierType } - .contains(ModifierTypes.STATIC) + val modifiers = modifiersForNode(fieldDecl) + val isStatic = modifiers.exists(_.modifierType == ModifierTypes.STATIC) + val modifierAsts = modifiers.map(Ast(_)) val declarationNode = createDotNetNodeInfo(fieldDecl.json(ParserKeys.Declaration)) val declAsts = astForVariableDeclaration(declarationNode, isStatic) @@ -185,7 +266,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val memberNodes = declAsts .flatMap(_.nodes.collectFirst { case x: NewIdentifier => x }) .map(x => memberNode(declarationNode, x.name, code(declarationNode), x.typeFullName)) - memberNodes.map(Ast(_).withChildren(annotationAsts).withChildren(astForModifiers(fieldDecl))) + memberNodes.map(Ast(_).withChildren(annotationAsts).withChildren(modifierAsts)) } protected def astForLocalDeclarationStatement(localDecl: DotNetNodeInfo): Seq[Ast] = { @@ -280,20 +361,12 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { .toSeq // TODO: Decide on proper return type for constructors. No `ReturnType` key in C# JSON for constructors so just // defaulted to void (same as java) for now - val methodReturn = newMethodReturnNode(BuiltinTypes.Void, None, None, None) - val signature = composeMethodLikeSignature( - BuiltinTypes.Void, - params.flatMap(_.nodes.collectFirst { case x: NewMethodParameterIn => x.typeFullName }) - ) + val methodReturn = newMethodReturnNode(DotNetTypeMap(BuiltinTypes.Void), None, None, None) + val signature = composeMethodLikeSignature(DotNetTypeMap(BuiltinTypes.Void), params) val typeDeclFullName = scope.surroundingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace); - val modifiers = - (astForModifiers(constructorDecl) :+ Ast(newModifierNode(ModifierTypes.CONSTRUCTOR))) - .flatMap(_.nodes) - .collect { case x: NewModifier => - x - } - .filter(_.modifierType != ModifierTypes.INTERNAL) + val modifiers = (modifiersForNode(constructorDecl) :+ newModifierNode(ModifierTypes.CONSTRUCTOR)) + .filter(_.modifierType != ModifierTypes.INTERNAL) val isStaticConstructor = modifiers.exists(_.modifierType == ModifierTypes.STATIC) @@ -331,7 +404,10 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { Seq(methodAst(methodNode_, thisNode +: params, body, methodReturn, modifiers)) } - protected def astForMethodDeclaration(methodDecl: DotNetNodeInfo): Seq[Ast] = { + protected def astForMethodDeclaration( + methodDecl: DotNetNodeInfo, + extraModifiers: List[NewModifier] = Nil + ): Seq[Ast] = { val name = nameFromNode(methodDecl) val params = methodDecl .json(ParserKeys.ParameterList) @@ -349,9 +425,8 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val methodReturnAstNode = createDotNetNodeInfo(methodDecl.json(ParserKeys.ReturnType)) val methodReturn = methodReturnNode(methodReturnAstNode, nodeTypeFullName(methodReturnAstNode)) - val signature = - methodSignature(methodReturn, params.flatMap(_.nodes.collectFirst { case x: NewMethodParameterIn => x })) - val fullName = s"${astFullName(methodDecl)}:$signature" + val signature = composeMethodLikeSignature(methodReturn.typeFullName, params) + val fullName = s"${astFullName(methodDecl)}:$signature" val methodNode_ = methodNode(methodDecl, name, code(methodDecl), fullName, Option(signature), relativeFileName) scope.pushNewScope(MethodScope(fullName)) @@ -361,17 +436,13 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { if (!jsonBody.isNull && parseLevel == AstParseLevel.FULL_AST) astForBlock(createDotNetNodeInfo(jsonBody)) else Ast(blockNode(methodDecl)) // Creates an empty block scope.popScope() - val modifiers = astForModifiers(methodDecl).flatMap(_.nodes).collect { case x: NewModifier => x } + val modifiers = modifiersForNode(methodDecl) ++ extraModifiers val thisNode = if (!modifiers.exists(_.modifierType == ModifierTypes.STATIC)) astForThisParameter(methodDecl) else Ast() Seq(methodAstWithAnnotations(methodNode_, thisNode +: params, body, methodReturn, modifiers, annotationAsts)) } - private def methodSignature(methodReturn: NewMethodReturn, params: Seq[NewMethodParameterIn]): String = { - s"${methodReturn.typeFullName}(${params.map(_.typeFullName).mkString(",")})" - } - private def astForParameter(paramNode: DotNetNodeInfo, idx: Int, paramTypeHint: Option[String] = None): Ast = { val name = nameFromNode(paramNode) val isVariadic = false // TODO @@ -384,14 +455,14 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { } private def astForThisParameter(methodDecl: DotNetNodeInfo): Ast = { - val name = "this" + val name = Constants.This val typeFullName = scope.surroundingTypeDeclFullName.getOrElse(Defines.Any) val param = parameterInNode(methodDecl, name, name, 0, false, EvaluationStrategies.BY_SHARING.name, typeFullName) Ast(param) } protected def astForThisReceiver(invocationExpr: DotNetNodeInfo, typeFullName: Option[String] = None): Ast = { - val name = "this" + val name = Constants.This val param = identifierNode( invocationExpr, name, @@ -423,10 +494,12 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { * https://learn.microsoft.com/en-us/dotnet/csharp/programming-guide/classes-and-structs/access-modifiers */ private def astForModifiers(declaration: DotNetNodeInfo): Seq[Ast] = { - val explicitModifiers = declaration.json(ParserKeys.Modifiers).arr.flatMap(astForModifier).toList - val accessModifiers = explicitModifiers - .flatMap(_.nodes) - .collect { case x: NewModifier => x.modifierType } intersect List( + modifiersForNode(declaration).map(Ast(_)) + } + + private def modifiersForNode(node: DotNetNodeInfo): Seq[NewModifier] = { + val explicitModifiers = node.json(ParserKeys.Modifiers).arr.flatMap(readModifier).toList + val accessModifiers = explicitModifiers.map(_.modifierType) intersect List( ModifierTypes.PUBLIC, ModifierTypes.PRIVATE, ModifierTypes.INTERNAL, @@ -435,28 +508,34 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { ) val implicitAccessModifier = accessModifiers match // Internal is default for top-level definitions - case Nil if scope.isTopLevel => Ast(newModifierNode(ModifierTypes.INTERNAL)) + case Nil if scope.isTopLevel => newModifierNode(ModifierTypes.INTERNAL) :: Nil // Private is default for nested definitions - case Nil => Ast(newModifierNode(ModifierTypes.PRIVATE)) - case _ => Ast() + case Nil => newModifierNode(ModifierTypes.PRIVATE) :: Nil + case _ => Nil - implicitAccessModifier :: explicitModifiers + implicitAccessModifier ++ explicitModifiers } - private def astForModifier(modifier: ujson.Value): Option[Ast] = { + private def readModifier(modifier: ujson.Value): Option[NewModifier] = { Option { modifier(ParserKeys.Value).str match - case "public" => newModifierNode(ModifierTypes.PUBLIC) - case "private" => newModifierNode(ModifierTypes.PRIVATE) - case "internal" => newModifierNode(ModifierTypes.INTERNAL) - case "static" => newModifierNode(ModifierTypes.STATIC) - case "readonly" => newModifierNode(ModifierTypes.READONLY) - case "virtual" => newModifierNode(ModifierTypes.VIRTUAL) - case "const" => newModifierNode(CSharpModifiers.CONST) + case "public" => newModifierNode(ModifierTypes.PUBLIC) + case "private" => newModifierNode(ModifierTypes.PRIVATE) + case "internal" => newModifierNode(ModifierTypes.INTERNAL) + case "static" => newModifierNode(ModifierTypes.STATIC) + case "readonly" => newModifierNode(ModifierTypes.READONLY) + case "virtual" => newModifierNode(ModifierTypes.VIRTUAL) + case "const" => newModifierNode(CSharpModifiers.CONST) + case "abstract" => newModifierNode(ModifierTypes.ABSTRACT) + case "protected" => newModifierNode(ModifierTypes.PROTECTED) case x => logger.warn(s"Unhandled modifier name '$x'") null - }.map(Ast(_)) + } + } + + private def astForModifier(modifier: ujson.Value): Option[Ast] = { + readModifier(modifier).map(Ast(_)) } protected def astVariableDeclarationForInitializedFields(fieldDecls: Seq[FieldDecl]): Seq[Ast] = { @@ -466,13 +545,52 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { } protected def astForPropertyDeclaration(propertyDecl: DotNetNodeInfo): Seq[Ast] = { - val propertyName = nameFromNode(propertyDecl) - val modifierAst = astForModifiers(propertyDecl) - val typeFullName = nodeTypeFullName(propertyDecl) + val accessorList = createDotNetNodeInfo(propertyDecl.json(ParserKeys.AccessorList)) + val accessors = accessorList.json(ParserKeys.Accessors).arr.map(createDotNetNodeInfo) + accessors.flatMap(astForPropertyAccessor(_, propertyDecl)).toList + } - val _memberNode = memberNode(propertyDecl, propertyName, propertyDecl.code, typeFullName) + private def astForPropertyAccessor(accessorDecl: DotNetNodeInfo, propertyDecl: DotNetNodeInfo): Seq[Ast] = { + accessorDecl.node match + case GetAccessorDeclaration => astForGetAccessorDeclaration(accessorDecl, propertyDecl) + case SetAccessorDeclaration => astForSetAccessorDeclaration(accessorDecl, propertyDecl) + case _ => + logger.warn(s"Unhandled property accessor '${accessorDecl.node}'") + Nil + } + + private def astForSetAccessorDeclaration(accessorDecl: DotNetNodeInfo, propertyDecl: DotNetNodeInfo): Seq[Ast] = { + val name = composeSetterName(nameFromNode(propertyDecl)) + val modifiers = modifiersForNode(propertyDecl) + val returnType = BuiltinTypes.Void + val valueType = nodeTypeFullName(propertyDecl) + val baseType = scope.surroundingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) + val isStatic = modifiers.exists(_.modifierType == ModifierTypes.STATIC) + val valueParam = Ast(NewMethodParameterIn().typeFullName(valueType).name("value").index(1)) + val parameters = Option.unless(isStatic)(astForThisParameter(propertyDecl)).toList :+ valueParam + val signature = composeMethodLikeSignature(returnType, parameters) + val fullName = composeMethodFullName(baseType, name, signature) + val body = Try(astForBlock(createDotNetNodeInfo(accessorDecl.json(ParserKeys.Body)))).getOrElse(Ast()) + val methodReturn = methodReturnNode(accessorDecl, returnType) + val methodNode_ = methodNode(accessorDecl, name, fullName, signature, relativeFileName) + + methodAst(methodNode_, parameters, body, methodReturn, modifiers) :: Nil + } - Seq(Ast(_memberNode).withChildren(modifierAst)) + private def astForGetAccessorDeclaration(accessorDecl: DotNetNodeInfo, propertyDecl: DotNetNodeInfo): Seq[Ast] = { + val name = composeGetterName(nameFromNode(propertyDecl)) + val modifiers = modifiersForNode(propertyDecl) + val returnType = nodeTypeFullName(propertyDecl) + val baseType = scope.surroundingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) + val isStatic = modifiers.exists(_.modifierType == ModifierTypes.STATIC) + val parameters = if isStatic then Nil else astForThisParameter(propertyDecl) :: Nil + val signature = composeMethodLikeSignature(returnType, parameters) + val fullName = composeMethodFullName(baseType, name, signature) + val body = Ast(blockNode(accessorDecl)) + val methodReturn = methodReturnNode(accessorDecl, returnType) + val methodNode_ = methodNode(accessorDecl, name, fullName, signature, relativeFileName) + + methodAst(methodNode_, parameters, body, methodReturn, modifiers) :: Nil } /** Creates an AST for a simple `x => { ... }` style lambda expression @@ -487,8 +605,12 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { paramTypeHint: Option[String] = None ): Seq[Ast] = { // Create method declaration - val name = nextClosureName() - val fullName = s"${scope.surroundingScopeFullName.getOrElse(Defines.UnresolvedNamespace)}.$name" + val name = nextClosureName() + val fullName = { + val baseType = withoutSignature(scope.surroundingScopeFullName.getOrElse(Defines.UnresolvedNamespace)) + val signature = Defines.UnresolvedSignature + composeMethodFullName(baseType, name, signature) + } // Set parameter type if necessary, which may require the type hint val paramType = paramTypeHint.flatMap(AstCreatorHelper.elementTypesFromCollectionType).headOption val paramAsts = Try(lambdaExpression.json(ParserKeys.Parameter)).toOption match { @@ -552,7 +674,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { def astForAnonymousObjectCreationExpression(anonObjExpr: DotNetNodeInfo): Seq[Ast] = { val typeDeclName = nextAnonymousTypeName() - val typeDeclFullName = s"${scope.surroundingScopeFullName.getOrElse(Defines.Any)}.${typeDeclName}" + val typeDeclFullName = s"${withoutSignature(scope.surroundingScopeFullName.getOrElse(Defines.Any))}.${typeDeclName}" val _typeDeclNode = typeDeclNode( anonObjExpr, diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala index 5c22db90103e..c56d637a3294 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -1,17 +1,20 @@ package io.joern.csharpsrc2cpg.astcreation -import io.joern.csharpsrc2cpg.datastructures.CSharpMethod +import io.joern.csharpsrc2cpg.astcreation.AstParseLevel.FULL_AST +import io.joern.csharpsrc2cpg.astcreation.BuiltinTypes.DotNetTypeMap +import io.joern.csharpsrc2cpg.datastructures.{CSharpMethod, FieldDecl} import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.* import io.joern.csharpsrc2cpg.parser.{DotNetNodeInfo, ParserKeys} +import io.joern.csharpsrc2cpg.utils.Utils.{composeMethodFullName, composeMethodLikeSignature} import io.joern.csharpsrc2cpg.{CSharpOperators, Constants} import io.joern.x2cpg.utils.NodeBuilders.{newCallNode, newIdentifierNode, newOperatorCallNode} import io.joern.x2cpg.{Ast, Defines, ValidationMode} -import io.shiftleft.codepropertygraph.generated.nodes.{NewFieldIdentifier, NewLiteral, NewTypeRef} +import io.shiftleft.codepropertygraph.generated.nodes.{NewLiteral, NewTypeRef} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import ujson.Value import scala.collection.mutable.ArrayBuffer -import scala.util.{Failure, Success, Try} +import scala.util.Try trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => def astForExpressionStatement(expr: DotNetNodeInfo): Seq[Ast] = { @@ -38,10 +41,212 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case ConditionalAccessExpression => astForConditionalAccessExpression(expr) case SuppressNullableWarningExpression => astForSuppressNullableWarningExpression(expr) case _: BaseLambdaExpression => astForSimpleLambdaExpression(expr) + case ParenthesizedExpression => astForParenthesizedExpression(expr) case _ => notHandledYet(expr) } } + /** Attempts to decide if [[expr]] denotes a setter property reference, in which case returns its corresponding + * [[CSharpMethod]] meta-data and class full name it belongs to. + */ + private def tryResolveSetterInvocation(expr: DotNetNodeInfo): Option[(CSharpMethod, String)] = { + val baseType = expr.node match { + case SimpleMemberAccessExpression => + val base = createDotNetNodeInfo(expr.json(ParserKeys.Expression)) + Some(nodeTypeFullName(base)) + case IdentifierName => + scope.surroundingTypeDeclFullName + case _ => + None + } + + val fieldName = nameFromNode(expr) + baseType.flatMap(x => scope.tryResolveSetterInvocation(fieldName, Some(x)).map((_, x))) + } + + private def stripAssignmentFromOperator(operatorName: String): Option[String] = operatorName match { + case Operators.assignmentPlus => Some(Operators.plus) + case Operators.assignmentMinus => Some(Operators.minus) + case Operators.assignmentMultiplication => Some(Operators.multiplication) + case Operators.assignmentDivision => Some(Operators.division) + case Operators.assignmentExponentiation => Some(Operators.exponentiation) + case Operators.assignmentModulo => Some(Operators.modulo) + case Operators.assignmentShiftLeft => Some(Operators.shiftLeft) + case Operators.assignmentLogicalShiftRight => Some(Operators.logicalShiftRight) + case Operators.assignmentArithmeticShiftRight => Some(Operators.arithmeticShiftRight) + case Operators.assignmentAnd => Some(Operators.and) + case Operators.assignmentOr => Some(Operators.or) + case Operators.assignmentXor => Some(Operators.xor) + case _ => None + } + + /** Mainly to abstract the lowering of +=, *=, etc. assignments when the LHS is a property. Takes care of building the + * RHS appropriately, e.g. by expanding `P += RHS` into `set_P(get_P() + RHS)`, etc. + * @param expr + * the full assignment expression, for `code`, `line`, etc. + * @param assignOp + * the assignment operator, cf. [[Operators]] + * @param setterInfo + * the setter meta-data, cf. [[tryResolveSetterInvocation]] + */ + private def lowerSetterAssignmentRhs( + expr: DotNetNodeInfo, + assignOp: String, + setterInfo: (CSharpMethod, String), + receiver: Option[Ast], + rhs: DotNetNodeInfo + ): Seq[Ast] = { + val (setterMethod, setterBaseType) = setterInfo + val propertyName = setterMethod.name.stripPrefix("set_") + val originalRhs = astForOperand(rhs) + + assignOp match { + case Operators.assignment => originalRhs + case _ => + scope.tryResolveGetterInvocation(propertyName, Some(setterBaseType)) match { + // Shouldn't happen, provided it is valid code. At any rate, log and emit the RHS verbatim. + case None => + logger.warn(s"Couldn't find matching getter for $propertyName in ${code(expr)}") + originalRhs + case Some(getterMethod) => + stripAssignmentFromOperator(assignOp) match { + case None => + logger.warn(s"Unrecognized assignment in ${code(expr)}") + originalRhs + case Some(opName) => + val getterInvocation = + createInvocationAst(expr, getterMethod.name, Nil, receiver, Some(getterMethod), Some(setterBaseType)) + val operatorCall = + newOperatorCallNode(opName, code(expr), Some(setterMethod.returnType), line(expr), column(expr)) + callAst(operatorCall, getterInvocation +: originalRhs, None, None) :: Nil + } + } + } + } + + /** Lowers assignments such as `x.P = RHS` and `x.P += RHS` with `P` denoting a setter property into calls + * `x.set_P(RHS)` and `x.set_P(x.get_P() + RHS)`. + * @param assignExpr + * the full assignment expression, for `code`, `line` properties + * @param assignOp + * the final assignment operator name, cf. [[Operators]] + * @param setterInfo + * the setter meta-data, cf. [[tryResolveSetterInvocation]] + */ + private def astForMemberAccessSetterAssignment( + assignExpr: DotNetNodeInfo, + lhs: DotNetNodeInfo, + assignOp: String, + rhs: DotNetNodeInfo, + setterInfo: (CSharpMethod, String) + ): Seq[Ast] = { + val (setterMethod, setterBaseType) = setterInfo + val receiver = if (setterMethod.isStatic) { + None + } else { + val baseNode = createDotNetNodeInfo(lhs.json(ParserKeys.Expression)) + astForNode(baseNode).headOption + } + val rhsAst = lowerSetterAssignmentRhs(assignExpr, assignOp, setterInfo, receiver, rhs) + + createInvocationAst( + assignExpr, + setterMethod.name, + rhsAst, + receiver, + Some(setterMethod), + Some(setterBaseType) + ) :: Nil + } + + /** Lowers assignments such as `P = RHS` and `P += RHS` with `P` an identifier denoting a setter property into calls + * `set_P(RHS)` and `set_P(get_P() + RHS)`, respectively. + * @param assignExpr + * the full assignment expression, for `code`, `line` properties. + * @param assignOp + * the final assignment operator name, cf. [[Operators]] + * @param setterInfo + * the setter meta-data, cf. [[tryResolveSetterInvocation]] + */ + private def astForIdentifierSetterAssignment( + assignExpr: DotNetNodeInfo, + lhs: DotNetNodeInfo, + assignOp: String, + rhs: DotNetNodeInfo, + setterInfo: (CSharpMethod, String) + ): Seq[Ast] = { + val (setterMethod, setterBaseType) = setterInfo + val receiver = Option.when(!setterMethod.isStatic)(astForThisReceiver(lhs, scope.surroundingTypeDeclFullName)) + val rhsAst = lowerSetterAssignmentRhs(assignExpr, assignOp, setterInfo, receiver, rhs) + + createInvocationAst( + assignExpr, + setterMethod.name, + rhsAst, + receiver, + Some(setterMethod), + Some(setterBaseType) + ) :: Nil + } + + /** Lowers assignments such as `x.P = RHS` and `P += RHS` where `P` denotes a setter property into a call + * `x.set_P(RHS)` and `set_P(get_P() + RHS)`, respectively. + * + * @param assignExpr + * the full assignment expr, for `code`, `line` properties + * @param setterInfo + * the setter meta-data, cf. [[tryResolveSetterInvocation]] + * @param assignOp + * the final assignment operator name, cf. [[Operators]] + */ + private def astForSetterAssignmentExpression( + assignExpr: DotNetNodeInfo, + setterInfo: (CSharpMethod, String), + lhs: DotNetNodeInfo, + assignOp: String, + rhs: DotNetNodeInfo + ): Seq[Ast] = { + lhs.node match { + case SimpleMemberAccessExpression => + astForMemberAccessSetterAssignment(assignExpr, lhs, assignOp, rhs, setterInfo) + case IdentifierName => astForIdentifierSetterAssignment(assignExpr, lhs, assignOp, rhs, setterInfo) + case _ => + logger.warn(s"Unsupported setter assignment: ${code(assignExpr)}") + Nil + } + } + + /** Lowers arbitrary assignment `LHS = RHS`, `LHS += RHS`, etc. expressions. + * @param assignExpr + * the full assignment, for `code`, `line` properties + * @param assignOp + * the final assignment operator name, cf. [[Operators]] + */ + private def astForAssignmentExpression( + assignExpr: DotNetNodeInfo, + lhs: DotNetNodeInfo, + assignOp: String, + rhs: DotNetNodeInfo + ): Seq[Ast] = { + tryResolveSetterInvocation(lhs) match { + case Some(setterInfo) => astForSetterAssignmentExpression(assignExpr, setterInfo, lhs, assignOp, rhs) + case None => astForRegularAssignmentExpression(assignExpr, lhs, assignOp, rhs) + } + } + + private def astForRegularAssignmentExpression( + assignExpr: DotNetNodeInfo, + lhs: DotNetNodeInfo, + assignOp: String, + rhs: DotNetNodeInfo + ): Seq[Ast] = { + astForRegularBinaryExpression(assignExpr, lhs, assignOp, rhs) + } + + private def astForParenthesizedExpression(parenExpr: DotNetNodeInfo): Seq[Ast] = { + astForNode(parenExpr.json(ParserKeys.Expression)) + } + private def astForAwaitExpression(awaitExpr: DotNetNodeInfo): Seq[Ast] = { /* fullName is the name in case of STATIC_DISPATCH */ val node = @@ -61,14 +266,34 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { protected def astForOperand(operandNode: DotNetNodeInfo): Seq[Ast] = { operandNode.node match { case IdentifierName => - List(scope.findFieldInScope(nameFromNode(operandNode)), scope.lookupVariable(nameFromNode(operandNode))) match { - case List(Some(_), None) => astForSimpleMemberAccessExpression(operandNode) + (scope.findFieldInScope(nameFromNode(operandNode)), scope.lookupVariable(nameFromNode(operandNode))) match { + case (Some(field), None) => createImplicitBaseFieldAccess(operandNode, field) case _ => astForNode(operandNode) } case _ => astForNode(operandNode) } } + private def createImplicitBaseFieldAccess(fieldNode: DotNetNodeInfo, field: FieldDecl): Seq[Ast] = { + // TODO: Maybe this should be a TypeRef, like we recently started doing for javasrc? + val baseNode = if (field.isStatic) { + newIdentifierNode(scope.surroundingTypeDeclFullName.getOrElse(Defines.Any), field.typeFullName) + } else { + newIdentifierNode(Constants.This, field.typeFullName) + } + + fieldAccessAst( + base = Ast(baseNode), + code = s"${baseNode.code}.${field.name}", + lineNo = fieldNode.lineNumber, + columnNo = fieldNode.columnNumber, + fieldName = field.name, + fieldTypeFullName = field.typeFullName, + fieldLineNo = fieldNode.lineNumber, + fieldColumnNo = fieldNode.columnNumber + ) :: Nil + } + protected def astForUnaryExpression(unaryExpr: DotNetNodeInfo): Seq[Ast] = { val operatorToken = unaryExpr.json(ParserKeys.OperatorToken)(ParserKeys.Value).str val operatorName = operatorToken match @@ -84,59 +309,44 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case "!" => Operators.logicalNot case "&" => Operators.addressOf - val args = createDotNetNodeInfo(unaryExpr.json(ParserKeys.Operand)) - val argsAst = astForOperand(args) + val args = createDotNetNodeInfo(unaryExpr.json(ParserKeys.Operand)) + val argsAst = astForOperand(args) + val callNode = operatorCallNode(unaryExpr, operatorName, Some(nodeTypeFullName(args))) - Seq( - callAst(createCallNodeForOperator(unaryExpr, operatorName, typeFullName = Some(nodeTypeFullName(args))), argsAst) - ) + callAst(callNode, argsAst) :: Nil } - protected def astForBinaryExpression(binaryExpr: DotNetNodeInfo): Seq[Ast] = { + private def astForBinaryExpression(binaryExpr: DotNetNodeInfo): Seq[Ast] = { + val lhsNode = createDotNetNodeInfo(binaryExpr.json(ParserKeys.Left)) + val rhsNode = createDotNetNodeInfo(binaryExpr.json(ParserKeys.Right)) val operatorToken = binaryExpr.json(ParserKeys.OperatorToken)(ParserKeys.Value).str - val operatorName = operatorToken match - case "+" => Operators.addition - case "-" => Operators.subtraction - case "*" => Operators.multiplication - case "/" => Operators.division - case "%" => Operators.modulo - case "==" => Operators.equals - case "!=" => Operators.notEquals - case "&&" => Operators.logicalAnd - case "||" => Operators.logicalOr - case "=" => Operators.assignment - case "+=" => Operators.assignmentPlus - case "-=" => Operators.assignmentMinus - case "*=" => Operators.assignmentMultiplication - case "/=" => Operators.assignmentDivision - case "%=" => Operators.assignmentModulo - case "&=" => Operators.assignmentAnd - case "|=" => Operators.assignmentOr - case "^=" => Operators.assignmentXor - case ">>=" => Operators.assignmentLogicalShiftRight - case "<<=" => Operators.assignmentShiftLeft - case ">" => Operators.greaterThan - case "<" => Operators.lessThan - case ">=" => Operators.greaterEqualsThan - case "<=" => Operators.lessEqualsThan - case "|" => Operators.or - case "&" => Operators.and - case "^" => Operators.xor - case x => - logger.warn(s"Unhandled operator '$x' for ${code(binaryExpr)}") + val operatorName = binaryOperatorsMap.getOrElse( + operatorToken, { + logger.warn(s"Unhandled operator '$operatorToken' for ${code(binaryExpr)}") CSharpOperators.unknown - - val args = astForOperand(createDotNetNodeInfo(binaryExpr.json(ParserKeys.Left))) ++: astForOperand( - createDotNetNodeInfo(binaryExpr.json(ParserKeys.Right)) + } ) + binaryExpr.node match { + case _: AssignmentExpr => astForAssignmentExpression(binaryExpr, lhsNode, operatorName, rhsNode) + case _ => astForRegularBinaryExpression(binaryExpr, lhsNode, operatorName, rhsNode) + } + } - val cNode = - createCallNodeForOperator( - binaryExpr, - operatorName, - typeFullName = Some(fixedTypeOperators.getOrElse(operatorName, getTypeFullNameFromAstNode(args))) - ) - Seq(callAst(cNode, args)) + /** @param binaryExpr + * the full binary expression, for `code`, `line`, etc. + * @param operatorName + * the final operator name, cf. [[Operators]] + */ + private def astForRegularBinaryExpression( + binaryExpr: DotNetNodeInfo, + lhs: DotNetNodeInfo, + operatorName: String, + rhs: DotNetNodeInfo + ): Seq[Ast] = { + val args = astForOperand(lhs) ++: astForOperand(rhs) + val typeFullName = fixedTypeOperators.get(operatorName).orElse(Some(getTypeFullNameFromAstNode(args))) + val callNode = operatorCallNode(binaryExpr, operatorName, typeFullName) + callAst(callNode, args) :: Nil } /** Handles the `= ...` part of the equals value clause, thus this only contains an RHS. @@ -202,70 +412,24 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } } - private def astForInvocationExpression(invocationExpr: DotNetNodeInfo): Seq[Ast] = { - val expression = createDotNetNodeInfo(invocationExpr.json(ParserKeys.Expression)) - val callName = nameFromNode(expression) - val argumentList = createDotNetNodeInfo(invocationExpr.json(ParserKeys.ArgumentList)) - - val ( - receiver: Option[Ast], - baseTypeFullName: Option[String], - methodMetaData: Option[CSharpMethod], - arguments: Seq[Ast] - ) = expression.node match { - case SimpleMemberAccessExpression | SuppressNullableWarningExpression => - val baseNode = createDotNetNodeInfo( - createDotNetNodeInfo(invocationExpr.json(ParserKeys.Expression)).json(ParserKeys.Expression) - ) - val receiverAst = astForNode(baseNode).toList - val baseTypeFullName = receiverAst match { - case head :: _ => Option(getTypeFullNameFromAstNode(head)).filterNot(_ == Defines.Any) - case _ => None - } - val arguments = astForArgumentList(argumentList, baseTypeFullName) - val argTypes = arguments.map(getTypeFullNameFromAstNode).toList - val methodMetaData = scope.tryResolveMethodInvocation(callName, argTypes, baseTypeFullName) - (receiverAst.headOption, baseTypeFullName, methodMetaData, arguments) - case IdentifierName | MemberBindingExpression => - // This is when a call is made directly, which could also be made from a static import - val argTypes = astForArgumentList(argumentList).map(getTypeFullNameFromAstNode).toList - scope - .tryResolveMethodInvocation(callName, argTypes) - .orElse(scope.tryResolveMethodInvocation(callName, argTypes, scope.surroundingTypeDeclFullName)) match { - case Some(methodMetaData) if methodMetaData.isStatic => - // If static, create implicit type identifier explicitly - val typeMetaData = scope.typeForMethod(methodMetaData) - val typeName = typeMetaData.flatMap(_.name.split("[.]").lastOption).getOrElse(Defines.Any) - val typeFullName = typeMetaData.map(_.name) - val receiverNode = Ast( - identifierNode(invocationExpr, typeName, typeName, typeFullName.getOrElse(Defines.Any)) - ) - val arguments = astForArgumentList(argumentList, typeFullName) - (Option(receiverNode), typeFullName, Option(methodMetaData), arguments) - case Some(methodMetaData) => - // If dynamic, create implicit `this` identifier explicitly - val typeMetaData = scope.typeForMethod(methodMetaData) - val typeFullName = typeMetaData.map(_.name) - val thisAst = astForThisReceiver(invocationExpr, typeFullName) - val arguments = astForArgumentList(argumentList, typeFullName) - (Option(thisAst), typeMetaData.map(_.name), Option(methodMetaData), arguments) - case None => - (None, None, None, Seq.empty[Ast]) - } - case x => - logger.warn(s"Unhandled LHS $x for InvocationExpression") - (None, None, None, Seq.empty[Ast]) - } + private def createInvocationAst( + invocationExpr: DotNetNodeInfo, + callName: String, + arguments: Seq[Ast], + baseAst: Option[Ast], + methodMetaData: Option[CSharpMethod], + baseTypeFullName: Option[String] + ): Ast = { val methodSignature = methodMetaData match { - case Some(m) => s"${m.returnType}(${m.parameterTypes.filterNot(_._1 == "this").map(_._2).mkString(",")})" - case None => Defines.UnresolvedSignature + case Some(m) => + val returnType = DotNetTypeMap.getOrElse(m.returnType, m.returnType) + composeMethodLikeSignature(returnType, m.parameterTypes.filterNot(_._1 == Constants.This).map(_._2)) + case None => Defines.UnresolvedSignature } val methodFullName = baseTypeFullName match { - case Some(typeFullName) => - s"$typeFullName.$callName:$methodSignature" - case _ => - s"${Defines.UnresolvedNamespace}.$callName:$methodSignature" + case Some(typeFullName) => composeMethodFullName(typeFullName, callName, methodSignature) + case _ => composeMethodFullName(Defines.UnresolvedNamespace, callName, methodSignature) } val dispatchType = methodMetaData .map(_.isStatic) @@ -286,74 +450,162 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { methodMetaData.map(_.returnType) ), arguments, - receiver + baseAst ) - Seq(_callAst) + + _callAst } - protected def astForSimpleMemberAccessExpression(accessExpr: DotNetNodeInfo): Seq[Ast] = { - val fieldIdentifierName = nameFromNode(accessExpr) + /** Handles expressions like `foo.Bar()`. If `Bar` can't be found inside `foo`'s class, attempts to find a compatible + * extension method. If all fails, an AST is still produced. + */ + private def astForMemberAccessInvocation( + invocationExpr: DotNetNodeInfo, + baseAst: Option[Ast], + argumentList: DotNetNodeInfo, + callName: String + ): Seq[Ast] = { - val (identifierName, typeFullName) = accessExpr.node match { - case SimpleMemberAccessExpression => { - createDotNetNodeInfo(accessExpr.json(ParserKeys.Expression)).node match - case SuppressNullableWarningExpression => - val baseNode = createDotNetNodeInfo(accessExpr.json(ParserKeys.Expression)(ParserKeys.Operand)) - val baseAst = astForNode(baseNode) - val baseTypeFullName = getTypeFullNameFromAstNode(baseAst) - - val fieldInScope = scope.tryResolveFieldAccess(fieldIdentifierName, typeFullName = Option(baseTypeFullName)) - - ( - nameFromNode(baseNode), - fieldInScope - .map(_.typeName) - .getOrElse(Defines.Any) - ) - case _ => { - val fieldInScope = scope.findFieldInScope(fieldIdentifierName) - val _identifierName = - if (fieldInScope.nonEmpty && fieldInScope.map(_.isStatic).contains(true)) - scope.surroundingTypeDeclFullName.getOrElse(Defines.Any) - else Constants.This - val _typeFullName = fieldInScope.map(_.typeFullName).getOrElse(Defines.Any) - (_identifierName, _typeFullName) - } - } - case _ => { - val fieldInScope = scope.findFieldInScope(fieldIdentifierName) - val _identifierName = - if (fieldInScope.nonEmpty && fieldInScope.map(_.isStatic).contains(true)) - scope.surroundingTypeDeclFullName.getOrElse(Defines.Any) - else Constants.This - val _typeFullName = fieldInScope.map(_.typeFullName).getOrElse(Defines.Any) - (_identifierName, _typeFullName) - } + val baseTypeFullName = Some(getTypeFullNameFromAstNode(baseAst.toList)).filterNot(_ == Defines.Any) + val arguments = astForArgumentList(argumentList, baseTypeFullName) + val argTypes = arguments.map(getTypeFullNameFromAstNode).toList + + val byMethod = scope.tryResolveMethodInvocation(callName, argTypes, baseTypeFullName) + lazy val byExtMethod = scope.tryResolveExtensionMethodInvocation(baseTypeFullName, callName, argTypes) + + val (method, baseType) = byMethod + .map(x => (Some(x), baseTypeFullName)) + .orElse(byExtMethod.map(x => (Some(x._1), Some(x._2)))) + .getOrElse((None, baseTypeFullName)) + + createInvocationAst(invocationExpr, callName, arguments, baseAst, method, baseType) :: Nil + } + + private def astForIdentifierInvocation( + invocationExpr: DotNetNodeInfo, + argumentList: DotNetNodeInfo, + callName: String + ): Seq[Ast] = { + // This is when a call is made directly, which could also be made from a static import + val argTypes = astForArgumentList(argumentList).map(getTypeFullNameFromAstNode).toList + val (receiver, baseType, method, args) = scope + .tryResolveMethodInvocation(callName, argTypes) + .orElse(scope.tryResolveMethodInvocation(callName, argTypes, scope.surroundingTypeDeclFullName)) match { + case Some(methodMetaData) if methodMetaData.isStatic => + // If static, create implicit type identifier explicitly + val typeMetaData = scope.typeForMethod(methodMetaData) + val typeName = typeMetaData.flatMap(_.name.split("[.]").lastOption).getOrElse(Defines.Any) + val typeFullName = typeMetaData.map(_.name) + val receiverNode = Ast(identifierNode(invocationExpr, typeName, typeName, typeFullName.getOrElse(Defines.Any))) + val arguments = astForArgumentList(argumentList, typeFullName) + (Option(receiverNode), typeFullName, Option(methodMetaData), arguments) + case Some(methodMetaData) => + // If dynamic, create implicit `this` identifier explicitly + val typeMetaData = scope.typeForMethod(methodMetaData) + val typeFullName = typeMetaData.map(_.name) + val thisAst = astForThisReceiver(invocationExpr, typeFullName) + val arguments = astForArgumentList(argumentList, typeFullName) + (Option(thisAst), typeMetaData.map(_.name), Option(methodMetaData), arguments) + case None => + (None, None, None, Seq.empty[Ast]) } - val identifier = newIdentifierNode(identifierName, typeFullName) + createInvocationAst(invocationExpr, callName, args, receiver, method, baseType) :: Nil + } - val fieldIdentifier = NewFieldIdentifier() - .code(fieldIdentifierName) - .canonicalName(fieldIdentifierName) - .lineNumber(accessExpr.lineNumber) - .columnNumber(accessExpr.columnNumber) + private def astForInvocationExpression(invocationExpr: DotNetNodeInfo): Seq[Ast] = { + val expression = createDotNetNodeInfo(invocationExpr.json(ParserKeys.Expression)) + val callName = nameFromNode(expression) + val argumentList = createDotNetNodeInfo(invocationExpr.json(ParserKeys.ArgumentList)) - val fieldAccessCode = s"$identifierName.$fieldIdentifierName" + expression.node match { + case SimpleMemberAccessExpression | SuppressNullableWarningExpression => + val baseAst = astForNode(createDotNetNodeInfo(expression.json(ParserKeys.Expression))) + astForMemberAccessInvocation(invocationExpr, baseAst.headOption, argumentList, callName) + case IdentifierName | MemberBindingExpression => + astForIdentifierInvocation(invocationExpr, argumentList, callName) + case x => + logger.warn(s"Unhandled LHS $x for InvocationExpression") + Nil + } + } - val fieldAccess = - newOperatorCallNode( - Operators.fieldAccess, - fieldAccessCode, - Some(typeFullName).orElse(Some(Defines.Any)), - accessExpr.lineNumber, - accessExpr.columnNumber - ) + /** Handles expressions like `foo.MyField`, where `MyField` is known to be a getter property. Getters are lowered into + * calls, e.g. (a) System.Console.Out becomes System.Console.get_Out(), because it's a static method; (b) x.KeyChar + * becomes System.ConsoleKeyInfo.get_KeyChar(x), because it's a dynamic method. + */ + private def astForMemberAccessGetterExpression( + getter: CSharpMethod, + baseAst: Ast, + baseTypeFullName: String, + accessExpr: DotNetNodeInfo + ): Seq[Ast] = { + if (getter.isStatic) { + callAst( + newCallNode( + getter.name, + Some(baseTypeFullName), + getter.returnType, + DispatchTypes.STATIC_DISPATCH, + Nil, + code(accessExpr), + line(accessExpr), + column(accessExpr) + ) + ) :: Nil + } else { + callAst( + newCallNode( + getter.name, + Some(baseTypeFullName), + getter.returnType, + DispatchTypes.DYNAMIC_DISPATCH, + baseTypeFullName :: Nil, + code(accessExpr), + line(accessExpr), + column(accessExpr) + ), + base = Some(baseAst) + ) :: Nil + } + } + + private def astForSimpleMemberAccessExpression(accessExpr: DotNetNodeInfo): Seq[Ast] = { + val fieldIdentifierName = nameFromNode(accessExpr) + val baseAst = astForNode(createDotNetNodeInfo(accessExpr.json(ParserKeys.Expression))).head + val baseTypeFullName = getTypeFullNameFromAstNode(baseAst) + + // The typical field access resolving mechanism + lazy val byFieldAccess = scope.tryResolveFieldAccess(fieldIdentifierName, Some(baseTypeFullName)) + + // Getters look like fields, but are underneath `get_`-prefixed methods + lazy val byPropertyName = scope.tryResolveGetterInvocation(fieldIdentifierName, Some(baseTypeFullName)) + + // accessExpr might be a qualified name e.g. `System.Console`, in which case `System` (baseAst) is not a type + // but a namespace. In this scenario, we look up the entire expression + lazy val byQualifiedName = scope.tryResolveTypeReference(accessExpr.code) - val identifierAst = Ast(identifier) - val fieldIdentAst = Ast(fieldIdentifier) + val (typeFullName, isGetter) = byFieldAccess + .map(x => (x.typeName, false)) + .orElse(byPropertyName.map(x => (x.returnType, true))) + .orElse(byQualifiedName.map(x => (x.name, false))) + .map((typeName, isGetter) => (scope.tryResolveTypeReference(typeName).map(_.name).getOrElse(typeName), isGetter)) + .getOrElse((Defines.Any, false)) - Seq(callAst(fieldAccess, Seq(identifierAst, fieldIdentAst))) + if (isGetter) { + astForMemberAccessGetterExpression(byPropertyName.get, baseAst, baseTypeFullName, accessExpr) + } else { + fieldAccessAst( + baseAst, + code(accessExpr), + line(accessExpr), + column(accessExpr), + fieldIdentifierName, + typeFullName, + line(accessExpr), + column(accessExpr) + ) :: Nil + } } protected def astForElementAccessExpression(elementAccessExpression: DotNetNodeInfo): Seq[Ast] = { @@ -521,21 +773,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { baseType: Option[String] = None ): Seq[Ast] = { val baseNode = createDotNetNodeInfo(condAccExpr.json(ParserKeys.Expression)) - val baseAst = astForNode(baseNode) - val baseTypeFullName = - if (getTypeFullNameFromAstNode(baseAst).equals(Defines.Any)) baseType - else Option(getTypeFullNameFromAstNode(baseAst)) + baseType.orElse(Some(getTypeFullNameFromAstNode(astForNode(baseNode)))).filterNot(_.equals(Defines.Any)) Try(createDotNetNodeInfo(condAccExpr.json(ParserKeys.WhenNotNull))).toOption match { case Some(node) => node.node match { - case ConditionalAccessExpression => - astForConditionalAccessExpression(node, baseTypeFullName) - case MemberBindingExpression => astForMemberBindingExpression(node, baseTypeFullName) - case InvocationExpression => - astForInvocationExpression(node) - case _ => astForNode(node) + case ConditionalAccessExpression => astForConditionalAccessExpression(node, baseTypeFullName) + case MemberBindingExpression => astForMemberBindingExpression(node, baseTypeFullName) + case _ => astForNode(node) } case None => Seq.empty[Ast] } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForPrimitivesCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForPrimitivesCreator.scala index 13c3ab0656b4..b194e6adf16a 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForPrimitivesCreator.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForPrimitivesCreator.scala @@ -23,7 +23,7 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { t case Some(field) if field.node.node != DotNetJsonAst.VariableDeclarator => astForFieldIdentifier(typeFullName, identifierName, field) case Some(field) => - Ast(identifierNode(ident, identifierName, ident.code, field.typeFullName)) + Ast(identifierNode(ident, identifierName, identifierName, field.typeFullName)) case None => // Check for static type reference scope.tryResolveTypeReference(identifierName) match { diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForStatementsCreator.scala index cf2de044a993..3319e267bf81 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForStatementsCreator.scala @@ -1,18 +1,15 @@ package io.joern.csharpsrc2cpg.astcreation import io.joern.csharpsrc2cpg.CSharpOperators -import io.joern.csharpsrc2cpg.parser.DotNetNodeInfo -import io.joern.csharpsrc2cpg.parser.ParserKeys import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.* -import io.joern.x2cpg.Ast -import io.joern.x2cpg.ValidationMode -import io.shiftleft.codepropertygraph.generated.ControlStructureTypes -import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.codepropertygraph.generated.nodes.NewControlStructure -import io.shiftleft.codepropertygraph.generated.nodes.NewIdentifier +import io.joern.csharpsrc2cpg.parser.{DotNetNodeInfo, ParserKeys} +import io.joern.x2cpg.{Ast, ValidationMode} +import io.joern.x2cpg.utils.NodeBuilders.newModifierNode +import io.shiftleft.codepropertygraph.generated.nodes.{NewFieldIdentifier, NewIdentifier, NewLiteral, NewLocal} +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators} import scala.:: -import scala.util.Try +import scala.util.{Try, Success, Failure} trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => @@ -63,22 +60,27 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t protected def astForStatement(nodeInfo: DotNetNodeInfo): Seq[Ast] = { nodeInfo.node match { - case ExpressionStatement => astForExpressionStatement(nodeInfo) - case GlobalStatement => astForGlobalStatement(nodeInfo) - case IfStatement => astForIfStatement(nodeInfo) - case ThrowStatement => astForThrowStatement(nodeInfo) - case TryStatement => astForTryStatement(nodeInfo) - case ForEachStatement => astForForEachStatement(nodeInfo) - case ForStatement => astForForStatement(nodeInfo) - case DoStatement => astForDoStatement(nodeInfo) - case WhileStatement => astForWhileStatement(nodeInfo) - case SwitchStatement => astForSwitchStatement(nodeInfo) - case UsingStatement => astForUsingStatement(nodeInfo) - case _: JumpStatement => astForJumpStatement(nodeInfo) - case _ => notHandledYet(nodeInfo) + case ExpressionStatement => astForExpressionStatement(nodeInfo) + case GlobalStatement => astForGlobalStatement(nodeInfo) + case IfStatement => astForIfStatement(nodeInfo) + case ThrowStatement => astForThrowStatement(nodeInfo) + case TryStatement => astForTryStatement(nodeInfo) + case ForEachStatement => astForForEachStatement(nodeInfo) + case ForStatement => astForForStatement(nodeInfo) + case DoStatement => astForDoStatement(nodeInfo) + case WhileStatement => astForWhileStatement(nodeInfo) + case SwitchStatement => astForSwitchStatement(nodeInfo) + case UsingStatement => astForUsingStatement(nodeInfo) + case LocalFunctionStatement => astForLocalFunctionStatement(nodeInfo) + case _: JumpStatement => astForJumpStatement(nodeInfo) + case _ => notHandledYet(nodeInfo) } } + private def astForLocalFunctionStatement(nodeInfo: DotNetNodeInfo): Seq[Ast] = { + astForMethodDeclaration(nodeInfo) + } + private def astForSwitchLabel(labelNode: DotNetNodeInfo): Seq[Ast] = { val caseNode = jumpTargetNode(labelNode, "case", labelNode.code) labelNode.node match @@ -166,22 +168,100 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t } private def astForForEachStatement(forEachStmt: DotNetNodeInfo): Seq[Ast] = { - val forEachNode = controlStructureNode(forEachStmt, ControlStructureTypes.FOR, forEachStmt.code) - val iterableAst = astForNode(forEachStmt.json(ParserKeys.Expression)) - val forEachBlockAst = astForBlock(createDotNetNodeInfo(forEachStmt.json(ParserKeys.Statement))) - - val identifierValue = forEachStmt.json(ParserKeys.Identifier)(ParserKeys.Value).str - val _identifierNode = + val int32Tfn = BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int) + val forEachNode = controlStructureNode(forEachStmt, ControlStructureTypes.FOR, forEachStmt.code) + // Create the collection AST + def newCollectionAst = astForNode(forEachStmt.json(ParserKeys.Expression)) + val collectionNode = createDotNetNodeInfo(forEachStmt.json(ParserKeys.Expression)) + val collectionCode = code(collectionNode) + // Create the iterator variable + val iterName = forEachStmt.json(ParserKeys.Identifier)(ParserKeys.Value).str + val iterNode = forEachStmt.json(ParserKeys.Type) + val iterNodeTfn = nodeTypeFullName(createDotNetNodeInfo(iterNode)) + val iterIdentifier = identifierNode( - node = createDotNetNodeInfo(forEachStmt.json(ParserKeys.Type)), - name = identifierValue, - code = identifierValue, - typeFullName = nodeTypeFullName(createDotNetNodeInfo(forEachStmt.json(ParserKeys.Type))) + node = createDotNetNodeInfo(iterNode), + name = iterName, + code = iterName, + typeFullName = iterNodeTfn + ) + val iterVarLocal = NewLocal().name(iterName).code(iterName).typeFullName(iterNodeTfn) + scope.addToScope(iterName, iterVarLocal) + // Create a de-sugared `idx` variable, i.e., var _idx_ = 0 + val idxName = "_idx_" + val idxLocal = NewLocal().name(idxName).code(idxName).typeFullName(int32Tfn) + val idxIdenAtAssign = identifierNode(node = collectionNode, name = idxName, code = idxName, typeFullName = int32Tfn) + val idxAssignment = + callNode(forEachStmt, s"$idxName = 0", Operators.assignment, Operators.assignment, DispatchTypes.STATIC_DISPATCH) + val idxAssigmentArgs = + List(Ast(idxIdenAtAssign), Ast(NewLiteral().code("0").typeFullName(BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int)))) + val idxAssignmentAst = callAst(idxAssignment, idxAssigmentArgs) + // Create condition based on `idx` variable, i.e., _idx_ < $collection.Count + val idxIdAtCond = idxIdenAtAssign.copy + val collectCountAccess = callNode( + forEachStmt, + s"$collectionCode.Count", + Operators.fieldAccess, + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH + ) + val fieldAccessAst = + callAst(collectCountAccess, newCollectionAst :+ Ast(NewFieldIdentifier().canonicalName("Count").code("Count"))) + val idxLt = + callNode( + forEachStmt, + s"$idxName < $collectionCode.Count", + Operators.lessThan, + Operators.lessThan, + DispatchTypes.STATIC_DISPATCH ) + val idxLtArgs = + List(Ast(idxIdAtCond), fieldAccessAst) + val ltCallCond = callAst(idxLt, idxLtArgs) + // Create the assignment from $element = $collection[_idx_++] + val idxIdAtCollAccess = idxIdenAtAssign.copy + val collectIdxAccess = callNode( + forEachStmt, + s"$collectionCode[$idxName++]", + Operators.indexAccess, + Operators.indexAccess, + DispatchTypes.STATIC_DISPATCH + ) + val postIncrAst = callAst( + callNode( + forEachStmt, + s"$idxName++", + Operators.postIncrement, + Operators.postIncrement, + DispatchTypes.STATIC_DISPATCH + ), + Ast(idxIdAtCollAccess) :: Nil + ) + val indexAccessAst = callAst(collectIdxAccess, newCollectionAst :+ postIncrAst) + val iteratorAssignmentNode = + callNode( + forEachStmt, + s"$iterName = $collectionCode[$idxName++]", + Operators.assignment, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val iteratorAssignmentArgs = List(Ast(iterIdentifier), indexAccessAst) + val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) - val iteratorVarAst = Ast(_identifierNode) + val forEachBlockAst = astForBlock(createDotNetNodeInfo(forEachStmt.json(ParserKeys.Statement))) - Seq(Ast(forEachNode).withChild(iteratorVarAst).withChildren(iterableAst).withChild(forEachBlockAst)) + forAst( + forNode = forEachNode, + locals = Ast(idxLocal) + .withRefEdge(idxIdenAtAssign, idxLocal) + .withRefEdge(idxIdAtCond, idxLocal) + .withRefEdge(idxIdAtCollAccess, idxLocal) :: Ast(iterVarLocal).withRefEdge(iterIdentifier, iterVarLocal) :: Nil, + conditionAsts = ltCallCond :: Nil, + initAsts = idxAssignmentAst :: Nil, + updateAsts = iteratorAssignmentAst :: Nil, + bodyAst = forEachBlockAst + ) :: Nil } private def astForElseStatement(elseParserNode: DotNetNodeInfo): Ast = { @@ -197,8 +277,14 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t } - protected def astForGlobalStatement(globalStatement: DotNetNodeInfo): Seq[Ast] = { - astForNode(globalStatement.json(ParserKeys.Statement)) + private def astForGlobalStatement(globalStatement: DotNetNodeInfo): Seq[Ast] = { + val stmtNodeInfo = createDotNetNodeInfo(globalStatement.json(ParserKeys.Statement)) + stmtNodeInfo.node match + // Denotes a top-level method declaration. These shall be added to the fictitious "main" created + // by `astForTopLevelStatements`. + case LocalFunctionStatement => + astForMethodDeclaration(stmtNodeInfo, extraModifiers = newModifierNode(ModifierTypes.STATIC) :: Nil) + case _ => astForNode(stmtNodeInfo) } private def astForJumpStatement(jumpStmt: DotNetNodeInfo): Seq[Ast] = { @@ -234,11 +320,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t case Some(_expr: ujson.Obj) => astForNode(createDotNetNodeInfo(_expr)) case _ => Seq.empty[Ast] } - val throwCall = createCallNodeForOperator( - throwStmt, - CSharpOperators.throws, - typeFullName = Option(getTypeFullNameFromAstNode(argsAst)) - ) + val throwCall = operatorCallNode(throwStmt, CSharpOperators.throws, Some(getTypeFullNameFromAstNode(argsAst))) Seq(callAst(throwCall, argsAst)) } @@ -280,11 +362,11 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t * Thus, this is lowered as a try-finally, with finally making a call to `Dispose` on the declared variable. */ private def astForUsingStatement(usingStmt: DotNetNodeInfo): Seq[Ast] = { - val tryNode = controlStructureNode(usingStmt, ControlStructureTypes.TRY, code(usingStmt)) + val tryNode = controlStructureNode(usingStmt, ControlStructureTypes.TRY, code(usingStmt)) + val declAst = + Try(createDotNetNodeInfo(usingStmt.json(ParserKeys.Declaration))).map(astForNode).getOrElse(scala.Seq.empty[Ast]) val tryNodeInfo = createDotNetNodeInfo(usingStmt.json(ParserKeys.Statement)) val tryAst = astForBlock(tryNodeInfo, Option("try")) - val declNode = createDotNetNodeInfo(usingStmt.json(ParserKeys.Declaration)) - val declAst = astForNode(declNode) val finallyAst = declAst.flatMap(_.nodes).collectFirst { case x: NewIdentifier => x.copy }.map { id => val callCode = s"${id.name}.Dispose()" diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstSummaryVisitor.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstSummaryVisitor.scala index 9b1ba7c73646..5bf2d14328b1 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstSummaryVisitor.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstSummaryVisitor.scala @@ -1,5 +1,6 @@ package io.joern.csharpsrc2cpg.astcreation +import flatgraph.DiffGraphApplier.applyDiff import io.joern.csharpsrc2cpg.Constants import io.joern.csharpsrc2cpg.datastructures.{ CSharpField, @@ -12,9 +13,8 @@ import io.joern.csharpsrc2cpg.datastructures.{ import io.joern.csharpsrc2cpg.parser.ParserKeys import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder, EdgeTypes} +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes} import io.shiftleft.semanticcpg.language.* -import overflowdb.{BatchedUpdate, Config} import scala.collection.mutable import scala.util.Using @@ -29,11 +29,11 @@ trait AstSummaryVisitor(implicit withSchemaValidation: ValidationMode) { this: A this.parseLevel = AstParseLevel.SIGNATURES val fileNode = NewFile().name(relativeFileName) val compilationUnit = createDotNetNodeInfo(parserResult.json(ParserKeys.AstRoot)) - Using.resource(Cpg.withConfig(Config.withoutOverflow())) { cpg => + Using.resource(Cpg.empty) { cpg => // Build and store compilation unit AST val ast = Ast(fileNode).withChildren(astForCompilationUnit(compilationUnit)) Ast.storeInDiffGraph(ast, diffGraph) - BatchedUpdate.applyDiff(cpg.graph, diffGraph) + applyDiff(cpg.graph, diffGraph) // Simulate AST Linker for global namespace val globalNode = NewNamespaceBlock().fullName(Constants.Global).name(Constants.Global) @@ -41,7 +41,7 @@ trait AstSummaryVisitor(implicit withSchemaValidation: ValidationMode) { this: A cpg.typeDecl .where(_.astParentFullNameExact(Constants.Global)) .foreach(globalDiffGraph.addEdge(globalNode, _, EdgeTypes.AST)) - BatchedUpdate.applyDiff(cpg.graph, globalDiffGraph) + applyDiff(cpg.graph, globalDiffGraph) // Summarize findings summarize(cpg) @@ -58,6 +58,8 @@ trait AstSummaryVisitor(implicit withSchemaValidation: ValidationMode) { this: A def imports = cpg.imports.importedEntity.toSet + def globalImports = cpg.imports.filter(_.code.startsWith("global")).importedEntity.toSet + def toMethod(m: Method): CSharpMethod = { CSharpMethod( m.name, @@ -71,14 +73,23 @@ trait AstSummaryVisitor(implicit withSchemaValidation: ValidationMode) { this: A CSharpField(f.name, f.typeFullName) } - val mapping = mutable.Map - .from(cpg.namespaceBlock.map { namespace => - namespace.fullName -> mutable.Set.from(namespace.typeDecl.map { typ => - CSharpType(typ.fullName, typ.method.map(toMethod).l, typ.member.map(toField).l) - }) - }) - .asInstanceOf[NamespaceToTypeMap] - CSharpProgramSummary(mapping, imports) + def toType(t: TypeDecl): CSharpType = { + CSharpType(t.fullName, t.method.map(toMethod).l, t.member.map(toField).l) + } + + val mapping = { + // TypeDecls found inside explicit namespace blocks + val withExplicitNamespace = cpg.namespaceBlock.map { namespace => + namespace.fullName -> mutable.Set.from(namespace.typeDecl.map(toType)) + } + + // TypeDecls found outside explicit namespace blocks + val withoutExplicitNamespace = Set("" -> mutable.Set.from(cpg.typeDecl.whereNot(_.namespaceBlock).map(toType))) + + mutable.Map.from(withExplicitNamespace ++ withoutExplicitNamespace) + } + + CSharpProgramSummary(mapping, imports, globalImports) } } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpProgramSummary.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpProgramSummary.scala index 6ff1c0715021..1f756d2f6054 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpProgramSummary.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpProgramSummary.scala @@ -1,6 +1,8 @@ package io.joern.csharpsrc2cpg.datastructures +import better.files.File.VisitOptions import io.joern.csharpsrc2cpg.Constants +import io.joern.x2cpg.SourceFiles import io.joern.x2cpg.datastructures.{FieldLike, MethodLike, OverloadableMethod, ProgramSummary, TypeLike} import org.slf4j.LoggerFactory import upickle.core.LinkedHashMap @@ -8,7 +10,6 @@ import upickle.default.* import java.io.{ByteArrayInputStream, InputStream} import scala.annotation.targetName -import scala.collection.mutable.ListBuffer import scala.io.Source import scala.util.{Failure, Success, Try} import java.net.JarURLConnection @@ -25,16 +26,50 @@ type NamespaceToTypeMap = mutable.Map[String, mutable.Set[CSharpType]] * @see * [[CSharpProgramSummary.jsonToInitialMapping]] for generating initial mappings. */ -case class CSharpProgramSummary(val namespaceToType: NamespaceToTypeMap, val imports: Set[String]) +case class CSharpProgramSummary(namespaceToType: NamespaceToTypeMap, imports: Set[String], globalImports: Set[String]) extends ProgramSummary[CSharpType, CSharpMethod, CSharpField] { def findGlobalTypes: Set[CSharpType] = namespaceToType.getOrElse(Constants.Global, Set.empty).toSet @targetName("appendAll") def ++=(other: CSharpProgramSummary): CSharpProgramSummary = { - new CSharpProgramSummary(ProgramSummary.merge(namespaceToType, other.namespaceToType), imports ++ other.imports) + new CSharpProgramSummary( + ProgramSummary.merge(namespaceToType, other.namespaceToType), + imports ++ other.imports, + globalImports ++ other.globalImports + ) } + private def allImports: Set[String] = imports ++ globalImports + + def appendImported(other: CSharpProgramSummary): CSharpProgramSummary = + this ++= other.filter(namespacePred = (ns, _) => allImports.contains(ns)) + + /** Builds a new `CSharpProgramSummary` by filtering the current one's fields. + * + * @param namespacePred + * filtering predicate for `namespaceToType` + * + * @param importsPred + * filtering predicate for `imports` + * + * @param globalImportsPred + * filtering predicate for `globalImports` + */ + def filter( + namespacePred: (String, mutable.Set[CSharpType]) => Boolean = (_, _) => true, + importsPred: String => Boolean = _ => true, + globalImportsPred: String => Boolean = _ => true + ): CSharpProgramSummary = + copy( + namespaceToType = mutable.Map.fromSpecific(namespaceToType.view.filter(namespacePred(_, _))), + imports = imports.filter(importsPred), + globalImports = globalImports.filter(globalImportsPred) + ) + + def addGlobalImports(imports: Set[String]): CSharpProgramSummary = + copy(globalImports = globalImports ++ imports) + } object CSharpProgramSummary { @@ -45,9 +80,10 @@ object CSharpProgramSummary { def apply( namespaceToType: NamespaceToTypeMap = mutable.Map.empty, - imports: Set[String] = Set.empty + imports: Set[String] = Set.empty, + globalImports: Set[String] = Set.empty ): CSharpProgramSummary = - new CSharpProgramSummary(namespaceToType, imports) + new CSharpProgramSummary(namespaceToType, imports, globalImports) def apply(summaries: Iterable[CSharpProgramSummary]): CSharpProgramSummary = summaries.foldLeft(CSharpProgramSummary())(_ ++= _) @@ -57,7 +93,7 @@ object CSharpProgramSummary { /** @return * a mapping of the `System` package types. */ - def BuiltinTypes: NamespaceToTypeMap = { + private def BuiltinTypes: NamespaceToTypeMap = { jsonToInitialMapping(mergeBuiltInTypesJson) match { case Failure(exception) => logger.warn("Unable to parse JSON type entry from builtin types", exception); mutable.Map.empty @@ -65,6 +101,43 @@ object CSharpProgramSummary { } } + /** Returns the `CSharpProgramSummary` for the builtin types bundle. + */ + def builtinTypesSummary: CSharpProgramSummary = + CSharpProgramSummary(BuiltinTypes) + + /** Returns the `CSharpProgramSummary` for the given JSON file paths. + * + * @param paths + * the JSON file paths to load types from + */ + def externalTypesSummary(paths: Set[String]): CSharpProgramSummary = + CSharpProgramSummary(fromExternalJsons(paths)) + + private def fromExternalJsons(paths: Set[String]): NamespaceToTypeMap = { + val jsonFiles = paths.flatMap(SourceFiles.determine(_, Set(".json"))(VisitOptions.default)).toList + val inputStreams = jsonFiles.flatMap { path => + Try(java.io.FileInputStream(path)) match { + case Success(stream) => Some(stream) + case Failure(exc) => + logger.warn(s"Unable to open file: $path", exc) + None + } + } + + if (inputStreams.isEmpty) { + logger.warn("No JSON files found in the provided paths.") + mutable.Map.empty + } else { + jsonToInitialMapping(loadAndMergeJsonStreams(inputStreams)) match { + case Success(mapping) => mapping + case Failure(exception) => + logger.warn("Failed to parsed merged JSON streams", exception) + mutable.Map.empty + } + } + } + /** Converts a JSON type mapping to a NamespaceToTypeMap entry. * @param jsonInputStream * a JSON file as an input stream. @@ -74,6 +147,30 @@ object CSharpProgramSummary { def jsonToInitialMapping(jsonInputStream: InputStream): Try[NamespaceToTypeMap] = Try(read[NamespaceToTypeMap](ujson.Readable.fromByteArray(jsonInputStream.readAllBytes()))) + private def loadAndMergeJsonStreams(jsonInputStreams: List[InputStream]): InputStream = { + val jsonObjects = for { + inputStream <- jsonInputStreams + jsonString = Source.fromInputStream(inputStream).mkString + jsonObject = ujson.read(jsonString).obj + } yield jsonObject + + val mergedJson = jsonObjects + .reduceOption((prev, curr) => { + curr.keys.foreach(key => { + prev.updateWith(key) { + case Some(x) => + Option(x.arr.addAll(curr.get(key).get.arr)) + case None => + Option(curr.get(key).get.arr) + } + }) + prev + }) + .getOrElse(LinkedHashMap[String, ujson.Value]()) + + new ByteArrayInputStream(writeToByteArray(ujson.read(mergedJson))) + } + private def mergeBuiltInTypesJson: InputStream = { val classLoader = getClass.getClassLoader val builtinDirectory = "builtin_types" @@ -117,30 +214,7 @@ object CSharpProgramSummary { logger.warn("No JSON files found.") InputStream.nullInputStream() } else { - val mergedJsonObjects = ListBuffer[LinkedHashMap[String, ujson.Value]]() - for (resourcePath <- resourcePaths) { - val inputStream = classLoader.getResourceAsStream(resourcePath) - val jsonString = Source.fromInputStream(inputStream).mkString - val jsonObject = ujson.read(jsonString).obj - mergedJsonObjects.addOne(jsonObject) - } - - val mergedJson: LinkedHashMap[String, ujson.Value] = - mergedJsonObjects - .reduceOption((prev, curr) => { - curr.keys.foreach(key => { - prev.updateWith(key) { - case Some(x) => - Option(x.arr.addAll(curr.get(key).get.arr)) - case None => - Option(curr.get(key).get.arr) - } - }) - prev - }) - .getOrElse(LinkedHashMap[String, ujson.Value]()) - - new ByteArrayInputStream(writeToByteArray(ujson.read(mergedJson))) + loadAndMergeJsonStreams(resourcePaths.map(classLoader.getResourceAsStream)) } } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala index 7ebe276c53c6..b600a5085ec8 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala @@ -1,17 +1,25 @@ package io.joern.csharpsrc2cpg.datastructures +import io.joern.csharpsrc2cpg.Constants +import io.joern.csharpsrc2cpg.utils.Utils import io.joern.x2cpg.Defines import io.joern.x2cpg.datastructures.{OverloadableScope, Scope, ScopeElement, TypedScope, TypedScopeElement} +import io.joern.x2cpg.utils.ListUtils.singleOrNone import io.shiftleft.codepropertygraph.generated.nodes.DeclarationNew +import io.joern.x2cpg.utils.ListUtils.singleOrNone import scala.collection.mutable +import scala.reflect.ClassTag class CSharpScope(summary: CSharpProgramSummary) extends Scope[String, DeclarationNew, TypedScopeElement] with TypedScope[CSharpMethod, CSharpField, CSharpType](summary) with OverloadableScope[CSharpMethod] { - override val typesInScope: mutable.Set[CSharpType] = mutable.Set.empty[CSharpType].addAll(summary.findGlobalTypes) + override val typesInScope: mutable.Set[CSharpType] = mutable.Set + .empty[CSharpType] + .addAll(summary.findGlobalTypes) + .addAll(summary.globalImports.flatMap(summary.namespaceToType.getOrElse(_, Set.empty))) /** @return * the surrounding type declaration if one exists. @@ -32,7 +40,7 @@ class CSharpScope(summary: CSharpProgramSummary) override def isOverloadedBy(method: CSharpMethod, argTypes: List[String]): Boolean = { method.parameterTypes - .filterNot(_._1 == "this") + .filterNot(_._1 == Constants.This) .map(_._2) .zip(argTypes) .count({ case (x, y) => x != y }) == 0 @@ -56,10 +64,16 @@ class CSharpScope(summary: CSharpProgramSummary) .exists(x => x.scopeNode.isInstanceOf[MethodScope] || x.scopeNode.isInstanceOf[TypeLikeScope]) override def tryResolveTypeReference(typeName: String): Option[CSharpType] = { - if (typeName == "this") { + if (typeName == Constants.This) { surroundingTypeDeclFullName.flatMap(summary.matchingTypes).headOption } else { - super.tryResolveTypeReference(typeName) + super.tryResolveTypeReference(typeName) match + case Some(x) => Some(x) + case None => + // typeName might be a fully-qualified name e.g. System.Console, in which case, even if we + // don't import System (i.e. System is not in typesInScope), we should still find it if it's + // in the type summaries and there's exactly 1 match. + Some(typeName).filter(_.contains(".")).flatMap(summary.matchingTypes.andThen(singleOrNone)) } } @@ -108,4 +122,64 @@ class CSharpScope(summary: CSharpProgramSummary) Option(top) } + /** Reduces [[typesInScope]] to contain only those types holding an extension method with the desired signature. + */ + private def extensionsInScopeFor( + extendedType: String, + callName: String, + argTypes: List[String] + ): mutable.Set[CSharpType] = { + typesInScope + .map(t => t.copy(methods = t.methods.filter(matchingExtensionMethod(extendedType, callName, argTypes)))) + .filter(_.methods.nonEmpty) + } + + /** Builds a predicate for matching [[CSharpMethod]] with an ad-hoc description of theirs. + */ + private def matchingExtensionMethod( + thisType: String, + name: String, + argTypes: List[String] + ): CSharpMethod => Boolean = { m => + // TODO: we should also compare argTypes, however we first need to account for: + // a) default valued parameters in CSharpMethod, to account for different arities + // b) compatible/sub types, i.e. System.String should unify with System.Object. + m.isStatic && m.name == name && m.parameterTypes.map(_._2).headOption.contains(thisType) + } + + /** Tries to find an extension method for [[baseTypeFullName]] with the given [[callName]] and [[argTypes]] in the + * types currently in scope. + * + * @param baseTypeFullName + * the extension method's `this` argument. + * @param callName + * the method name + * @param argTypes + * the method's argument types, excluding `this` + * @return + * the method metadata, together with the class name where it can be found + */ + def tryResolveExtensionMethodInvocation( + baseTypeFullName: Option[String], + callName: String, + argTypes: List[String] + ): Option[(CSharpMethod, String)] = { + baseTypeFullName.flatMap(extensionsInScopeFor(_, callName, argTypes).headOption).map(x => (x.methods.head, x.name)) + } + + def tryResolveGetterInvocation( + fieldIdentifierName: String, + baseTypeFullName: Option[String] + ): Option[CSharpMethod] = { + val getterMethodName = Utils.composeGetterName(fieldIdentifierName) + tryResolveMethodInvocation(getterMethodName, Nil, baseTypeFullName) + } + + def tryResolveSetterInvocation( + fieldIdentifierName: String, + baseTypeFullName: Option[String] + ): Option[CSharpMethod] = { + val setterMethodName = Utils.composeSetterName(fieldIdentifierName) + tryResolveMethodInvocation(setterMethodName, Defines.Any :: Nil, baseTypeFullName) + } } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala index 67ebeadaa1a8..4ba9bd26586c 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala @@ -134,27 +134,29 @@ object DotNetJsonAst { object LogicalNotExpression extends UnaryExpr object AddressOfExpression extends UnaryExpr - sealed trait BinaryExpr extends BaseExpr - object AddExpression extends BinaryExpr - object SubtractExpression extends BinaryExpr - object MultiplyExpression extends BinaryExpr - object DivideExpression extends BinaryExpr - object ModuloExpression extends BinaryExpr - object EqualsExpression extends BinaryExpr - object NotEqualsExpression extends BinaryExpr - object LogicalAndExpression extends BinaryExpr - object LogicalOrExpression extends BinaryExpr - object AddAssignmentExpression extends BinaryExpr - object SubtractAssignmentExpression extends BinaryExpr - object MultiplyAssignmentExpression extends BinaryExpr - object DivideAssignmentExpression extends BinaryExpr - object ModuloAssignmentExpression extends BinaryExpr - object AndAssignmentExpression extends BinaryExpr - object OrAssignmentExpression extends BinaryExpr - object ExclusiveOrAssignmentExpression extends BinaryExpr - object RightShiftAssignmentExpression extends BinaryExpr - object LeftShiftAssignmentExpression extends BinaryExpr - object SimpleAssignmentExpression extends BinaryExpr + sealed trait BinaryExpr extends BaseExpr + object AddExpression extends BinaryExpr + object SubtractExpression extends BinaryExpr + object MultiplyExpression extends BinaryExpr + object DivideExpression extends BinaryExpr + object ModuloExpression extends BinaryExpr + object EqualsExpression extends BinaryExpr + object NotEqualsExpression extends BinaryExpr + object LogicalAndExpression extends BinaryExpr + object LogicalOrExpression extends BinaryExpr + + sealed trait AssignmentExpr extends BinaryExpr + object AddAssignmentExpression extends AssignmentExpr + object SubtractAssignmentExpression extends AssignmentExpr + object MultiplyAssignmentExpression extends AssignmentExpr + object DivideAssignmentExpression extends AssignmentExpr + object ModuloAssignmentExpression extends AssignmentExpr + object AndAssignmentExpression extends AssignmentExpr + object OrAssignmentExpression extends AssignmentExpr + object ExclusiveOrAssignmentExpression extends AssignmentExpr + object RightShiftAssignmentExpression extends AssignmentExpr + object LeftShiftAssignmentExpression extends AssignmentExpr + object SimpleAssignmentExpression extends AssignmentExpr object GreaterThanExpression extends BinaryExpr object LessThanExpression extends BinaryExpr @@ -226,6 +228,8 @@ object DotNetJsonAst { object ReturnStatement extends JumpStatement + object LocalFunctionStatement extends DeclarationExpr with BaseStmt + object AwaitExpression extends BaseExpr object PropertyDeclaration extends DeclarationExpr @@ -270,14 +274,28 @@ object DotNetJsonAst { object Attribute extends BaseExpr + object AttributeArgumentList extends BaseExpr + + object AttributeArgument extends BaseExpr + + object ParenthesizedExpression extends BaseExpr + object Unknown extends DotNetParserNode + object AccessorList extends DotNetParserNode + + object GetAccessorDeclaration extends DotNetParserNode + + object SetAccessorDeclaration extends DotNetParserNode + } /** The JSON key values, in alphabetical order. */ object ParserKeys { + val AccessorList = "AccessorList" + val Accessors = "Accessors" val AstRoot = "AstRoot" val Arguments = "Arguments" val ArgumentList = "ArgumentList" @@ -303,6 +321,7 @@ object ParserKeys { val ExpressionBody = "ExpressionBody" val Finally = "Finally" val FileName = "FileName" + val GetAccessorDeclaration = "GetAccessorDeclaration" val Identifier = "Identifier" val Incrementors = "Incrementors" val Initializer = "Initializer" @@ -325,6 +344,7 @@ object ParserKeys { val ParameterList = "ParameterList" val Pattern = "Pattern" val Sections = "Sections" + val SetAccessorDeclaration = "SetAccessorDeclaration" val SingleVariableDesignation = "SingleVariableDesignation" val Statement = "Statement" val Statements = "Statements" diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/passes/DependencyPass.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/passes/DependencyPass.scala index dcf6a21c4bcc..209a712c656c 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/passes/DependencyPass.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/passes/DependencyPass.scala @@ -45,8 +45,8 @@ class DependencyPass(cpg: Cpg, buildFiles: List[String], registerPackageId: Stri } val packageVersion = packageReference.attribute("Version").map(_.toString()).getOrElse("") val dependencyNode = NewDependency() - .name(packageName) - .version(packageVersion) + .name(packageName.trim()) + .version(packageVersion.trim()) builder.addNode(dependencyNode) } match { case Failure(exception) => diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DependencyDownloader.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DependencyDownloader.scala index b935912769b2..a58a8bf005a2 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DependencyDownloader.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DependencyDownloader.scala @@ -88,8 +88,13 @@ class DependencyDownloader( case Success(x) => x } - def createUrl(packageType: String, version: String): URL = { - URI(s"https://$NUGET_BASE_API_V2/$packageType/${dependencyName}/$version").toURL + def createUrl(packageType: String, version: String): Option[URL] = { + Try(new URI(s"https://$NUGET_BASE_API_V2/$packageType/${dependencyName}/$version").toURL) match { + case Success(url) => Some(url) + case Failure(e) => + logger.debug(s"Failed to create URL for packageType: $packageType, version: $version. Error: ${e.getMessage}") + None + } } // If dependency version is not specified, latest is returned @@ -115,50 +120,54 @@ class DependencyDownloader( * @return * the package version. */ - private def downloadPackage(targetDir: File, dependency: Dependency, url: URL): Unit = { + private def downloadPackage(targetDir: File, dependency: Dependency, url: Option[URL]): Unit = { var connection: Option[HttpURLConnection] = None - try { - connection = Option(url.openConnection()).collect { case x: HttpURLConnection => x } - // allow both GZip and Deflate (ZLib) encodings - connection.foreach(_.setRequestProperty("Accept-Encoding", "gzip, deflate")) - connection match { - case Some(conn: HttpURLConnection) if conn.getResponseCode == HttpURLConnection.HTTP_OK => - val ext = if url.toString.contains("/package/") then "nupkg" else "snupkg" - val fileName = targetDir / s"${dependency.name}.$ext" - - val inputStream = Option(conn.getContentEncoding) match { - case Some(encoding) if encoding.equalsIgnoreCase("gzip") => GZIPInputStream(conn.getInputStream) - case Some(encoding) if encoding.equalsIgnoreCase("deflate") => InflaterInputStream(conn.getInputStream) - case _ => conn.getInputStream - } - - Try { - Using.resources(inputStream, new FileOutputStream(fileName.pathAsString)) { (is, fos) => - val buffer = new Array[Byte](4096) - Iterator - .continually(is.read(buffer)) - .takeWhile(_ != -1) - .foreach(bytesRead => fos.write(buffer, 0, bytesRead)) - } - } match { - case Failure(exception) => - logger.error( - s"Exception occurred while downloading $fileName (${dependency.name}:${dependency.version})", - exception - ) - case Success(_) => - logger.info(s"Successfully downloaded dependency ${dependency.name}:${dependency.version}") + url.foreach { validUrl => + { + try { + connection = Option(validUrl.openConnection()).collect { case x: HttpURLConnection => x } + // allow both GZip and Deflate (ZLib) encodings + connection.foreach(_.setRequestProperty("Accept-Encoding", "gzip, deflate")) + connection match { + case Some(conn: HttpURLConnection) if conn.getResponseCode == HttpURLConnection.HTTP_OK => + val ext = if url.toString.contains("/package/") then "nupkg" else "snupkg" + val fileName = targetDir / s"${dependency.name}.$ext" + + val inputStream = Option(conn.getContentEncoding) match { + case Some(encoding) if encoding.equalsIgnoreCase("gzip") => GZIPInputStream(conn.getInputStream) + case Some(encoding) if encoding.equalsIgnoreCase("deflate") => InflaterInputStream(conn.getInputStream) + case _ => conn.getInputStream + } + + Try { + Using.resources(inputStream, new FileOutputStream(fileName.pathAsString)) { (is, fos) => + val buffer = new Array[Byte](4096) + Iterator + .continually(is.read(buffer)) + .takeWhile(_ != -1) + .foreach(bytesRead => fos.write(buffer, 0, bytesRead)) + } + } match { + case Failure(exception) => + logger.error( + s"Exception occurred while downloading $fileName (${dependency.name}:${dependency.version})", + exception + ) + case Success(_) => + logger.info(s"Successfully downloaded dependency ${dependency.name}:${dependency.version}") + } + case Some(conn: HttpURLConnection) => + logger.error(s"Connection to $url responded with non-200 code ${conn.getResponseCode}") + case _ => + logger.error(s"Unknown URL connection made, aborting") } - case Some(conn: HttpURLConnection) => - logger.error(s"Connection to $url responded with non-200 code ${conn.getResponseCode}") - case _ => - logger.error(s"Unknown URL connection made, aborting") + } catch { + case exception: Throwable => + logger.error(s"Unable to download dependency ${dependency.name}:${dependency.version}", exception) + } finally { + connection.foreach(_.disconnect()) + } } - } catch { - case exception: Throwable => - logger.error(s"Unable to download dependency ${dependency.name}:${dependency.version}", exception) - } finally { - connection.foreach(_.disconnect()) } } @@ -217,5 +226,4 @@ class DependencyDownloader( .map(CSharpProgramSummary(_)) CSharpProgramSummary(summaries) } - } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DotNetAstGenRunner.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DotNetAstGenRunner.scala index e5ac1153b08c..ff1793251ab8 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DotNetAstGenRunner.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DotNetAstGenRunner.scala @@ -16,8 +16,9 @@ class DotNetAstGenRunner(config: Config) extends AstGenRunnerBase(config) { private val logger = LoggerFactory.getLogger(getClass) // The x86 variant seems to run well enough on MacOS M-family chips, whereas the ARM build crashes - override val MacArm: String = MacX86 - override val WinArm: String = WinX86 + override val MacArm: String = MacX86 + override val WinArm: String = WinX86 + override val LinuxArm: String = "linux-arm64" override def fileFilter(file: String, out: File): Boolean = { file.stripSuffix(".json").replace(out.pathAsString, config.inputPath) match { @@ -30,11 +31,14 @@ class DotNetAstGenRunner(config: Config) extends AstGenRunnerBase(config) { override def skippedFiles(in: File, astGenOut: List[String]): List[String] = { val diagnosticMap = mutable.LinkedHashMap.empty[String, Seq[String]] - def addReason(reason: String, lastFile: Option[String] = None) = { - val key = lastFile.getOrElse(diagnosticMap.last._1) - diagnosticMap.updateWith(key) { - case Some(x) => Option(x :+ reason) - case None => Option(reason :: Nil) + def addReason(reason: String, lastFile: Option[String] = None): Unit = { + val key = lastFile.orElse(diagnosticMap.lastOption.map(_._1)) + + key.foreach { resolvedKey => + diagnosticMap.updateWith(resolvedKey) { + case Some(existingReasons) => Some(existingReasons :+ reason) + case None => Some(List(reason)) + } } } @@ -63,8 +67,8 @@ class DotNetAstGenRunner(config: Config) extends AstGenRunnerBase(config) { override def runAstGenNative(in: String, out: File, exclude: String, include: String)(implicit metaData: AstGenProgramMetaData ): Try[Seq[String]] = { - val excludeCommand = if (exclude.isEmpty) "" else s"-e \"$exclude\"" - ExternalCommand.run(s"$astGenCommand -o ${out.toString()} -i \"$in\" $excludeCommand", ".") + val excludeCommand = if (exclude.isEmpty) Seq.empty else Seq("-e", exclude) + ExternalCommand.run(Seq(astGenCommand, "-o", out.toString(), "-i", in) ++ excludeCommand, ".").toTry } } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/ImplicitUsingsCollector.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/ImplicitUsingsCollector.scala new file mode 100644 index 000000000000..14c4c26ce2b2 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/ImplicitUsingsCollector.scala @@ -0,0 +1,115 @@ +package io.joern.csharpsrc2cpg.utils + +import better.files.File +import io.joern.semanticcpg.utils.SecureXmlParsing + +import scala.xml.{Elem, Node} + +/** Depending on the project type defined in `.csproj` files, different sets of global usings are turned on by default. + * Here we collect them all. + */ +object ImplicitUsingsCollector { + + /** Collects implicit global imports extracted from `.csproj` files. + * + * @param buildFiles + * paths to `.csproj` files + * @return + * the list of implicitly turned on global imports + */ + def collect(buildFiles: List[String]): List[String] = { + buildFiles.flatMap { csproj => + SecureXmlParsing.parseXml(File(csproj).contentAsString) match + case Some(xml) => from(xml) + case None => List.empty + } + } + + // See https://learn.microsoft.com/en-gb/dotnet/core/project-sdk/overview#implicit-using-directives + private val projectTypeMapping: Map[String, List[String]] = { + val netSdkNamespace = List( + "System", + "System.Collections.Generic", + "System.IO", + "System.Linq", + "System.Net.Http", + "System.Threading", + "System.Threading.Tasks" + ) + Map( + "Microsoft.NET.Sdk" -> netSdkNamespace, + "Microsoft.NET.Sdk.Web" -> netSdkNamespace.appendedAll( + List( + "System.Net.Http.Json", + "Microsoft.AspNetCore.Builder", + "Microsoft.AspNetCore.Hosting", + "Microsoft.AspNetCore.Http", + "Microsoft.AspNetCore.Routing", + "Microsoft.Extensions.Configuration", + "Microsoft.Extensions.DependencyInjection", + "Microsoft.Extensions.Hosting", + "Microsoft.Extensions.Logging" + ) + ), + "Microsoft.NET.Sdk.Worker" -> netSdkNamespace.appendedAll( + List( + "Microsoft.Extensions.Configuration", + "Microsoft.Extensions.DependencyInjection", + "Microsoft.Extensions.Hosting", + "Microsoft.Extensions.Logging" + ) + ), + "Microsoft.NET.Sdk.WindowsDesktop" -> netSdkNamespace.appendedAll(List("System.Drawing", "System.Windows.Forms")) + ) + } + + /** Extracts implicit usings based on the project type, i.e. based on the `` tag. + */ + private def from(rootElem: Elem): List[String] = { + val projectType = rootElem.label match + case "Project" => rootElem.attribute("Sdk").flatMap(_.headOption.map(_.text)) + case _ => None + + val implicitUsingsEnabled = rootElem.child + .collect { case x if x.label == "PropertyGroup" => x.child } + .flatten + .collect { case x if x.label == "ImplicitUsings" => x.text } + .exists(x => x == "true" || x == "enable") + + val usingsFromProjectType = if (projectType.isDefined && implicitUsingsEnabled) { + projectTypeMapping.getOrElse(projectType.get, Nil) + } else { + Nil + } + + // Once we gather the initial set of implicit usings (if any) based on the project type, we + // process ItemGroup.Using tags. The order in which we process these matters, e.g. + // + // + // removes "System", whereas + // + // + // adds "System". + + rootElem.child + .collect { case x if x.label == "ItemGroup" => x.child } + .flatten + .collect { case x if x.label == "Using" => x } + .flatten + .foldLeft(usingsFromProjectType.toSet) { case (acc, node) => + if (node.attribute("Remove").isDefined) { + node.attribute("Remove").flatMap(_.headOption.map(_.text)) match + case None => acc + case Some(toRemove) => acc.excl(toRemove) + } else if (node.attribute("Include").isDefined) { + node.attribute("Include").flatMap(_.headOption.map(_.text)) match + case None => acc + case Some(toInclude) => acc + toInclude + } else { + acc + } + } + .toList + } + +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/ProgramSummaryCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/ProgramSummaryCreator.scala new file mode 100644 index 000000000000..7d2ae042a7b4 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/ProgramSummaryCreator.scala @@ -0,0 +1,43 @@ +package io.joern.csharpsrc2cpg.utils + +import io.joern.csharpsrc2cpg.Config +import io.joern.csharpsrc2cpg.astcreation.AstCreator +import io.joern.csharpsrc2cpg.datastructures.CSharpProgramSummary +import io.joern.x2cpg.utils.ConcurrentTaskUtil +import org.slf4j.LoggerFactory + +import scala.util.{Failure, Success} + +/** Builds a `CSharpProgramSummary` by pre-parsing AST creators for high level structures, taking into account related + * frontend options. + */ +object ProgramSummaryCreator { + + private val logger = LoggerFactory.getLogger(getClass) + + def from(astCreators: Seq[AstCreator], config: Config): CSharpProgramSummary = { + val internalSummary = summarizeAstCreators(astCreators) + val externalSummary = buildExternalSummary(config.useBuiltinSummaries, config.externalSummaryPaths) + internalSummary.appendImported(externalSummary) + } + + private def summarizeAstCreators(astCreators: Seq[AstCreator]): CSharpProgramSummary = { + ConcurrentTaskUtil + .runUsingThreadPool(astCreators.map(x => () => x.summarize()).iterator) + .flatMap { + case Failure(exception) => + logger.warn(s"Unable to pre-parse C# file, skipping - ", exception) + None + case Success(summary) => Option(summary) + } + .foldLeft(CSharpProgramSummary(imports = CSharpProgramSummary.initialImports))(_ ++= _) + } + + private def buildExternalSummary(withBuiltinTypes: Boolean, withJsonFiles: Set[String]): CSharpProgramSummary = { + val builtin = if withBuiltinTypes then CSharpProgramSummary.builtinTypesSummary else CSharpProgramSummary() + val fromJson = + if withJsonFiles.nonEmpty then CSharpProgramSummary.externalTypesSummary(withJsonFiles) + else CSharpProgramSummary() + builtin ++= fromJson + } +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/Utils.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/Utils.scala index f2acc6347d4b..e71cb46e3d2b 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/Utils.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/Utils.scala @@ -1,12 +1,42 @@ package io.joern.csharpsrc2cpg.utils +import io.joern.x2cpg.Ast +import io.shiftleft.codepropertygraph.generated.nodes.NewMethodParameterIn + object Utils { def composeMethodLikeSignature(returnType: String, parameterTypes: collection.Seq[String]): String = { s"$returnType(${parameterTypes.mkString(",")})" } + def composeMethodLikeSignature(returnType: String, parameters: Seq[Ast] = Nil): String = { + composeMethodLikeSignature( + returnType, + parameters.flatMap(_.nodes.collectFirst { case x: NewMethodParameterIn => x.typeFullName }) + ) + } + def composeMethodFullName(typeDeclFullName: String, name: String, signature: String): String = { s"$typeDeclFullName.$name:$signature" } + def composeGetterName(fieldIdentifierName: String): String = s"get_$fieldIdentifierName" + + def composeSetterName(fieldIdentifierName: String): String = s"set_$fieldIdentifierName" + + /** Generates the fictitious class name that holds top-level statements. + */ + def composeTopLevelClassName(fileName: String): String = { + val sanitizedFileName = fileName.replace(java.io.File.separator, "_").replace(".", "_") + s"${sanitizedFileName}_Program" + } + + /** Strips the signature part from [[fullName]]. + * + * Useful when handling nested methods, as method full names include signatures. To avoid a nested method's full name + * containing both its parent's signature and its own, we remove the parent's signature when entering its scope. + */ + def withoutSignature(fullName: String): String = fullName.split(':').toList match + case fn :: sig :: Nil => fn + case _ => fullName + } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/io/CSharp2CpgHTTPServerTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/io/CSharp2CpgHTTPServerTests.scala new file mode 100644 index 000000000000..639704906c97 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/io/CSharp2CpgHTTPServerTests.scala @@ -0,0 +1,78 @@ +package io.joern.csharpsrc2cpg.io + +import better.files.File +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.joern.x2cpg.utils.server.FrontendHTTPClient +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader +import io.shiftleft.semanticcpg.language.* +import org.scalatest.BeforeAndAfterAll + +import scala.collection.parallel.CollectionConverters.RangeIsParallelizable +import scala.util.Failure +import scala.util.Success + +class CSharp2CpgHTTPServerTests extends CSharpCode2CpgFixture with BeforeAndAfterAll { + + private var port: Int = -1 + + private def newProjectUnderTest(index: Option[Int] = None): File = { + val dir = File.newTemporaryDirectory("csharp2cpgTestsHttpTest") + val file = dir / "main.cs" + file.createIfNotExists(createParents = true) + val indexStr = index.map(_.toString).getOrElse("") + file.writeText(basicBoilerplate(s"Console.WriteLine($indexStr);")) + file.deleteOnExit() + dir.deleteOnExit() + } + + override def beforeAll(): Unit = { + // Start server + port = io.joern.csharpsrc2cpg.Main.startup() + } + + override def afterAll(): Unit = { + // Stop server + io.joern.csharpsrc2cpg.Main.stop() + } + + "Using csharp2cpg in server mode" should { + "build CPGs correctly (single test)" in { + val cpgOutFile = File.newTemporaryFile("csharp2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("Main") + cpg.call.code.l shouldBe List("Console.WriteLine()") + } + } + + "build CPGs correctly (multi-threaded test)" in { + (0 until 10).par.foreach { index => + val cpgOutFile = File.newTemporaryFile("csharp2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest(Some(index)) + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("Main") + cpg.call.code.l shouldBe List(s"Console.WriteLine($index)") + } + } + } + } + +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/io/ProjectParseTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/io/ProjectParseTests.scala index 1dc9312deaad..bd00ad8a77bc 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/io/ProjectParseTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/io/ProjectParseTests.scala @@ -1,11 +1,11 @@ package io.joern.csharpsrc2cpg.io import better.files.File -import io.joern.csharpsrc2cpg.datastructures.CSharpProgramSummary +import io.joern.csharpsrc2cpg.CSharpSrc2Cpg +import io.joern.csharpsrc2cpg.Config import io.joern.csharpsrc2cpg.passes.AstCreationPass import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture import io.joern.csharpsrc2cpg.utils.DotNetAstGenRunner -import io.joern.csharpsrc2cpg.{CSharpSrc2Cpg, Config} import io.joern.x2cpg.X2Cpg.newEmptyCpg import io.joern.x2cpg.utils.Report import io.shiftleft.codepropertygraph.generated.Cpg diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/passes/DependencyTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/passes/DependencyTests.scala index 488baaa5a695..f3310796d2c1 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/passes/DependencyTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/passes/DependencyTests.scala @@ -3,6 +3,7 @@ package io.joern.csharpsrc2cpg.passes import io.joern.csharpsrc2cpg.Config import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture import io.shiftleft.semanticcpg.language.* +import io.shiftleft.utils.ProjectRoot class DependencyTests extends CSharpCode2CpgFixture { @@ -101,7 +102,73 @@ class DependencyTests extends CSharpCode2CpgFixture { fail("Expected a call node for `Entity`") } } + } + + "a `csproj` file specifying a built-in dependency but built-in type summaries are disabled" when { + val csCode = """ + |using Microsoft.EntityFrameworkCore; + | + |public class Foo + |{ + | static void bar(ModelBuilder modelBuilder) + | { + | modelBuilder.Entity("test"); + | } + |}""".stripMargin + val csProj = """ + | + | + | + | + | + |""".stripMargin + + "the ability to download dependencies is also turned off" should { + val cpg = code(csCode) + .moreCode(csProj, "Foo.csproj") + .withConfig(Config().withUseBuiltinSummaries(false).withDownloadDependencies(false)) + "not resolve the call since there are no type summaries available for it" in { + inside(cpg.call("Entity").headOption) { + case Some(entity) => entity.methodFullName shouldBe "ModelBuilder.Entity:" + case None => fail("Expected call node for `Entity`") + } + } + } + "the ability to download dependencies is turned on" should { + val cpg = code(csCode) + .moreCode(csProj, "Foo.csproj") + .withConfig(Config().withUseBuiltinSummaries(false).withDownloadDependencies(true)) + "resolve the call since the dependency shall be downloaded and a type summary for it be built" in { + inside(cpg.call("Entity").headOption) { + case Some(entity) => + entity.methodFullName shouldBe "Microsoft.EntityFrameworkCore.ModelBuilder.Entity:Microsoft.EntityFrameworkCore.Metadata.Builders.EntityTypeBuilder(System.String)" + case None => fail("Expected call node for `Entity`") + } + } + } + + "download dependencies is disabled but external-summary-paths is pointing to the built-in directory" should { + val externalSummaryPaths = + Set(ProjectRoot.relativise("joern-cli/frontends/csharpsrc2cpg/src/main/resources/builtin_types")) + val cpg = code(csCode) + .moreCode(csProj, "Foo.csproj") + .withConfig( + Config() + .withDownloadDependencies(false) + .withUseBuiltinSummaries(false) + .withExternalSummaryPaths(externalSummaryPaths) + ) + + "resolve the call since its summary can be found in the provided directory" in { + inside(cpg.call("Entity").headOption) { + case Some(entity) => + entity.methodFullName shouldBe "Microsoft.EntityFrameworkCore.ModelBuilder.Entity:Microsoft.EntityFrameworkCore.Metadata.Builders.EntityTypeBuilder(System.String)" + case None => fail("Expected call node for `Entity`") + } + } + + } } "a `csproj` file specifying a dependency with the `Update` attribute" should { @@ -160,9 +227,35 @@ class DependencyTests extends CSharpCode2CpgFixture { .withConfig(config) inside(cpg.dependency.l) { case dep :: Nil => - dep.name shouldBe " System.Security.Cryptography.Pkcs" + dep.name shouldBe "System.Security.Cryptography.Pkcs" dep.version shouldBe "6.0.4" } } } + + "a csproj file specifying a dependency with whitespaces in between" should { + val config = Config().withDownloadDependencies(true); + "not throw a exception" in { + val cpg = code(""" + |namespace Foo; + |""".stripMargin) + .moreCode( + """ + | + | + | + | + | + | + |""".stripMargin, + "Foo.csproj" + ) + .withConfig(config) + + inside(cpg.dependency.l) { case dep :: dep2 :: Nil => + dep.name shouldBe "System.Security.Cryptography.Pkcs" + dep.version shouldBe "6 .0.4" + } + } + } } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ClassDeclarationTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ClassDeclarationTests.scala new file mode 100644 index 000000000000..975b891259f6 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ClassDeclarationTests.scala @@ -0,0 +1,44 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.ModifierTypes +import io.shiftleft.semanticcpg.language.* + +class ClassDeclarationTests extends CSharpCode2CpgFixture { + + "empty abstract class" should { + val cpg = code(""" + |abstract class C {} + |""".stripMargin) + + "have correct modifiers" in { + cpg.typeDecl.nameExact("C").modifier.modifierType.sorted.l shouldBe List( + ModifierTypes.ABSTRACT, + ModifierTypes.INTERNAL + ) + } + } + + "class with member of the same type involved in a fieldAccess" should { + val cpg = code(""" + |namespace Foo; + |class Bar + |{ + | Bar Field; + | void DoStuff() + | { + | var x = this.Field; + | } + |} + |""".stripMargin) + + "have correct typeDecl properties" in { + cpg.typeDecl.nameExact("Bar").fullName.l shouldBe List("Foo.Bar") + } + + "have correct member properties" in { + cpg.typeDecl.nameExact("Bar").member.nameExact("Field").typeFullName.l shouldBe List("Foo.Bar") + } + } + +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ControlStructureTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ControlStructureTests.scala index f3f3d6d0180d..ee53365e504b 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ControlStructureTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ControlStructureTests.scala @@ -173,4 +173,37 @@ class ControlStructureTests extends CSharpCode2CpgFixture { } + "a variable defined within a using statement" should { + val cpg = code(""" + |namespace other + |{ + | public class General + | { + | public static void Call(string name) + | { + | using (SqlConnection connection = new SqlConnection(name)) + | { + | try + | { + | connection.Open(); + | } + | catch (Exception ex) + | { + | Console.WriteLine(ex.Message); + | connection.Close(); + | } + | } + | } + | } + |} + |""".stripMargin) + + "partially resolve calls on the defined variable" in { + inside(cpg.call.name("Open").methodFullName.l) { + case x :: Nil => x shouldBe "SqlConnection.Open:" + case _ => fail("Unexpected call node structure") + } + } + } + } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ExtensionMethodTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ExtensionMethodTests.scala new file mode 100644 index 000000000000..1a8d25141f3c --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ExtensionMethodTests.scala @@ -0,0 +1,257 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, ModifierTypes} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier} +import io.shiftleft.semanticcpg.language.* + +class ExtensionMethodTests extends CSharpCode2CpgFixture { + + "nullary extension-method declaration" should { + val cpg = code(""" + |class MyClass {} + |static class Extensions + |{ + | public static void DoStuff(this MyClass myClass) {} + |} + |""".stripMargin) + + "have correct properties" in { + inside(cpg.method.nameExact("DoStuff").l) { + case doStuff :: Nil => + doStuff.fullName shouldBe "Extensions.DoStuff:System.Void(MyClass)" + doStuff.signature shouldBe "System.Void(MyClass)" + doStuff.methodReturn.typeFullName shouldBe "System.Void" + doStuff.modifier.modifierType.toSet shouldBe Set(ModifierTypes.STATIC, ModifierTypes.PUBLIC) + case xs => fail(s"Expected single DoStuff method, but got $xs") + } + } + + "have correct parameters" in { + inside(cpg.method.nameExact("DoStuff").parameter.sortBy(_.index).l) { + case myClass :: Nil => + myClass.typeFullName shouldBe "MyClass" + myClass.code shouldBe "this MyClass myClass" + myClass.name shouldBe "myClass" + case xs => fail(s"Expected single parameter, but got $xs") + } + } + } + + "nullary extension-method call" should { + val cpg = code(""" + |var x = new MyClass(); + |x.DoStuff(); + | + |class MyClass {} + |static class Extensions + |{ + | public static void DoStuff(this MyClass myClass) {} + |} + |""".stripMargin) + + "have correct properties" in { + inside(cpg.call.nameExact("DoStuff").l) { + case doStuff :: Nil => + doStuff.code shouldBe "x.DoStuff()" + doStuff.methodFullName shouldBe "Extensions.DoStuff:System.Void(MyClass)" + doStuff.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + case xs => fail(s"Expected single DoStuff call, but got $xs") + } + } + + "have correct arguments" in { + inside(cpg.call.nameExact("DoStuff").argument.sortBy(_.argumentIndex).l) { + case (x: Identifier) :: Nil => + x.argumentIndex shouldBe 0 + x.name shouldBe "x" + x.typeFullName shouldBe "MyClass" + x.code shouldBe "x" + case xs => fail(s"Expected single identifier argument to DoStuff, but got $xs") + } + } + } + + "two same-named extension methods in different namespaces" should { + val cpg = code(""" + |using Version1; + |var x = new MyClass(); + |x.DoStuff(0); + | + |class MyClass {} + |""".stripMargin) + .moreCode(""" + |namespace Version1; + | + |static class Extension1 + |{ + | public static void DoStuff(this MyClass myClass, int z) {} + |} + |""".stripMargin) + .moreCode(""" + |namespace Version2; + | + |static class Extension2 + |{ + | public static void DoStuff(this MyClass myClass, int z) {} + |} + |""".stripMargin) + + "find the correct extension method" in { + inside(cpg.call.nameExact("DoStuff").l) { + case doStuff :: Nil => + doStuff.code shouldBe "x.DoStuff(0)" + doStuff.methodFullName shouldBe "Version1.Extension1.DoStuff:System.Void(MyClass,System.Int32)" + case xs => fail(s"Expected single DoStuff call, but got $xs") + } + } + } + + "two same-named extension methods involving explicit sub-types" should { + + "map to the compile-time type (1)" in { + val cpg = code(""" + |var x = new MyConcrete(); + |x.DoStuff(); + | + |abstract class MyAbstract; + |class MyConcrete : MyAbstract; + | + |static class Extensions + |{ + | public static int DoStuff(this MyAbstract myAbstract) => 1; + | public static int DoStuff(this MyConcrete myConcrete) => 2; + |} + |""".stripMargin) + cpg.call.nameExact("DoStuff").methodFullName.l shouldBe List("Extensions.DoStuff:System.Int32(MyConcrete)") + } + + "map to the compile-time type (2)" in { + val cpg = code(""" + |MyAbstract x = new MyConcrete(); + |x.DoStuff(); + | + |abstract class MyAbstract; + |class MyConcrete : MyAbstract; + | + |static class Extensions + |{ + | public static int DoStuff(this MyAbstract myAbstract) => 1; + | public static int DoStuff(this MyConcrete myConcrete) => 2; + |} + |""".stripMargin) + cpg.call.nameExact("DoStuff").methodFullName.l shouldBe List("Extensions.DoStuff:System.Int32(MyAbstract)") + } + } + + "calling an extension method for `List`" should { + + "resolve correctly if the receiver is of type `List`" in { + val cpg = code(""" + |using System.Collections.Generic; + | + |var x = new List(); + |x.DoStuff(); + | + |static class Extensions + |{ + | public static int DoStuff(this List myList) => 1; + |} + |""".stripMargin) + + cpg.call.nameExact("DoStuff").methodFullName.l shouldBe List("Extensions.DoStuff:System.Int32(List)") + } + + "resolve correctly if there's only 1 type-parametric extension for `List`" in { + val cpg = code(""" + |using System.Collections.Generic; + | + |var x = new List(); + |x.DoStuff(); + | + |static class Extensions + |{ + | public static int DoStuff(this List myList) => 1; + |} + |""".stripMargin) + + cpg.call.nameExact("DoStuff").methodFullName.l shouldBe List("Extensions.DoStuff:System.Int32(List)") + } + + // TODO: The two `DoStuff` methods have the same methodFullName. + "resolve correctly if there are 2 possible extensions, one for `List` and another for `List`" ignore { + val cpg = code(""" + |using System.Collections.Generic; + | + |var x = new List(); + |x.DoStuff(); + | + |static class Extensions + |{ + | public static int DoStuff(this List myList) { return 1; } + | public static int DoStuff(this List myList) { return 2; } + |} + |""".stripMargin) + + cpg.call.nameExact("DoStuff").callee.l shouldBe cpg.literal("2").method.l + } + + "resolve correctly if the extra argument is type-compatible with the extension method's extra parameter" in { + val cpg = code(""" + |using System.Collections.Generic; + | + |var x = new List(); + |x.DoStuff(null); + | + |static class Extensions + |{ + | public static int DoStuff(this List myList, object x) { return 2; } + |} + |""".stripMargin) + + cpg.call.nameExact("DoStuff").callee.l shouldBe cpg.literal("2").method.l + } + } + + "consecutive unary extension method calls" should { + val cpg = code(""" + |var x = new MyClass(); + |var y = x.Foo().Bar(); + | + |class MyClass {} + |static class Extensions + |{ + | public static MyClass Foo(this MyClass c) => c; + | public static int Bar(this MyClass c) => 1; + |} + |""".stripMargin) + + "have correct properties and arguments" in { + inside(cpg.call.nameExact("Bar").l) { + case bar :: Nil => + bar.code shouldBe "x.Foo().Bar()" + bar.methodFullName shouldBe "Extensions.Bar:System.Int32(MyClass)" + bar.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + inside(bar.argument.sortBy(_.argumentIndex).l) { + case (foo: Call) :: Nil => + foo.code shouldBe "x.Foo()" + foo.methodFullName shouldBe "Extensions.Foo:MyClass(MyClass)" + foo.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + inside(foo.argument.sortBy(_.argumentIndex).l) { + case (x: Identifier) :: Nil => + x.code shouldBe "x" + x.name shouldBe "x" + x.typeFullName shouldBe "MyClass" + case xs => fail(s"Expected identifier argument to Foo, but got $xs") + } + case xs => fail(s"Expected single call argument to Bar, but got $xs") + } + case xs => fail(s"Expected single call to Bar, but got $xs") + } + } + + "have correct properties for the result of the chained call" in { + cpg.assignment.target.isIdentifier.nameExact("y").typeFullName.l shouldBe List("System.Int32") + } + } +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/FieldAccessTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/FieldAccessTests.scala new file mode 100644 index 000000000000..726a2b81478d --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/FieldAccessTests.scala @@ -0,0 +1,137 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} +import io.shiftleft.semanticcpg.language.* + +class FieldAccessTests extends CSharpCode2CpgFixture { + + "Console.WriteLine call while importing System" should { + val cpg = code(""" + |using System; + |Console.WriteLine("foo"); + |""".stripMargin) + + "have WriteLine call correctly set" in { + inside(cpg.call.nameExact("WriteLine").l) { + case writeLine :: Nil => + writeLine.code shouldBe "Console.WriteLine(\"foo\")" + writeLine.methodFullName shouldBe "System.Console.WriteLine:System.Void(System.String)" + case xs => fail(s"Expected single WriteLine call, but got $xs") + } + } + + "have foo literal correctly set" in { + inside(cpg.call.nameExact("WriteLine").argument(1).isLiteral.l) { + case foo :: Nil => + foo.typeFullName shouldBe "System.String" + foo.code shouldBe "\"foo\"" + case xs => fail(s"Expected single literal argument to WriteLine, but got $xs") + } + } + + "have Console correctly set" in { + inside(cpg.call.nameExact("WriteLine").argument(0).isIdentifier.l) { + case console :: Nil => + console.code shouldBe "Console" + console.typeFullName shouldBe "System.Console" + case xs => fail(s"Expected single Console identifier, but got $xs") + } + } + } + + "System.Console.WriteLine call while importing System" should { + val cpg = code(""" + |using System; + |System.Console.WriteLine("foo"); + |""".stripMargin) + + "have WriteLine call correctly set" in { + inside(cpg.call.nameExact("WriteLine").l) { + case writeLine :: Nil => + writeLine.methodFullName shouldBe "System.Console.WriteLine:System.Void(System.String)" + writeLine.code shouldBe "System.Console.WriteLine(\"foo\")" + case xs => fail(s"Expected single WriteLine call, but got $xs") + } + } + + "have foo literal correctly set" in { + inside(cpg.call.nameExact("WriteLine").argument(1).isLiteral.l) { + case foo :: Nil => + foo.typeFullName shouldBe "System.String" + foo.code shouldBe "\"foo\"" + case xs => fail(s"Expected single literal argument to WriteLine, but got $xs") + } + } + + "have System.Console correctly set" in { + inside(cpg.call.nameExact("WriteLine").argument(0).fieldAccess.l) { + case sysConsole :: Nil => + sysConsole.typeFullName shouldBe "System.Console" + sysConsole.code shouldBe "System.Console" + sysConsole.fieldIdentifier.code.l shouldBe List("Console") + sysConsole.fieldIdentifier.canonicalName.l shouldBe List("Console") + case xs => fail(s"Expected single fieldAccess to the left of WriteLine, but got $xs") + } + } + } + + "System.Console.WriteLine call without importing System" should { + val cpg = code(""" + |System.Console.WriteLine("foo"); + |""".stripMargin) + + "have WriteLine call correctly set" in { + inside(cpg.call.nameExact("WriteLine").l) { + case writeLine :: Nil => + writeLine.methodFullName shouldBe "System.Console.WriteLine:System.Void(System.String)" + writeLine.code shouldBe "System.Console.WriteLine(\"foo\")" + case xs => fail(s"Expected single WriteLine call, but got $xs") + } + } + + "have foo literal correctly set" in { + inside(cpg.call.nameExact("WriteLine").argument(1).isLiteral.l) { + case foo :: Nil => + foo.typeFullName shouldBe "System.String" + foo.code shouldBe "\"foo\"" + case xs => fail(s"Expected single literal argument to WriteLine, but got $xs") + } + } + + "have System.Console correctly set" in { + inside(cpg.call.nameExact("WriteLine").argument(0).fieldAccess.l) { + case sysConsole :: Nil => + sysConsole.typeFullName shouldBe "System.Console" + sysConsole.code shouldBe "System.Console" + sysConsole.fieldIdentifier.code.l shouldBe List("Console") + sysConsole.fieldIdentifier.canonicalName.l shouldBe List("Console") + case xs => fail(s"Expected single fieldAccess to the left of WriteLine, but got $xs") + } + } + } + + "field access via explicit `this.X`" should { + val cpg = code(""" + |using System; + |class C + |{ + | int x; + | C() + | { + | Console.WriteLine(this.x); + | } + |}""".stripMargin) + "have correct type for `this.x`" in { + inside(cpg.call("WriteLine").argument(1).fieldAccess.l) { + case fieldAccess :: Nil => + fieldAccess.code shouldBe "this.x" + fieldAccess.typeFullName shouldBe "System.Int32" + fieldAccess.methodFullName shouldBe Operators.fieldAccess + fieldAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + fieldAccess.referencedMember.l shouldBe cpg.typeDecl.nameExact("C").member.nameExact("x").l + case xs => fail(s"Expected single fieldAccess, but got $xs") + } + } + } +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ImplicitUsingsTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ImplicitUsingsTests.scala new file mode 100644 index 000000000000..ac5ce4015941 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ImplicitUsingsTests.scala @@ -0,0 +1,232 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.semanticcpg.language.* + +class ImplicitUsingsTests extends CSharpCode2CpgFixture { + + "top-level WriteLine call" when { + + "accompanied by a NET.Sdk csproj with ImplicitUsings set to `enable`" should { + val cpg = code(""" + |Console.WriteLine("Foo"); + |""".stripMargin) + .moreCode( + """ + | + | + | Exe + | enable + | + | + |""".stripMargin, + fileName = "App.csproj" + ) + + "resolve WriteLine call" in { + cpg.call.nameExact("WriteLine").methodFullName.l shouldBe List( + "System.Console.WriteLine:System.Void(System.String)" + ) + } + } + + "accompanied by a NET.Sdk csproj with ImplicitUsings set to `true`" should { + val cpg = code(""" + |Console.WriteLine("Foo"); + |""".stripMargin) + .moreCode( + """ + | + | + | Exe + | true + | + | + |""".stripMargin, + fileName = "App.csproj" + ) + + "resolve WriteLine call" in { + cpg.call.nameExact("WriteLine").methodFullName.l shouldBe List( + "System.Console.WriteLine:System.Void(System.String)" + ) + } + } + + "accompanied by a NET.Sdk csproj with ImplicitUsings omitted" should { + val cpg = code(""" + |Console.WriteLine("Foo"); + |""".stripMargin) + .moreCode( + """ + | + | + | Exe + | + | + |""".stripMargin, + fileName = "App.csproj" + ) + + "not resolve WriteLine call" in { + cpg.call.nameExact("WriteLine").methodFullName.l shouldBe List( + ".WriteLine:" + ) + } + + } + + "accompanied by a NET.Sdk csproj with ImplicitUsings set to `false`" should { + val cpg = code(""" + |Console.WriteLine("Foo"); + |""".stripMargin) + .moreCode( + """ + | + | + | Exe + | false + | + | + |""".stripMargin, + fileName = "App.csproj" + ) + + "not resolve WriteLine call" in { + cpg.call.nameExact("WriteLine").methodFullName.l shouldBe List( + ".WriteLine:" + ) + } + } + + "accompanied by a NET.Sdk csproj with ImplicitUsings set to `disable`" should { + val cpg = code(""" + |Console.WriteLine("Foo"); + |""".stripMargin) + .moreCode( + """ + | + | + | Exe + | disable + | + | + |""".stripMargin, + fileName = "App.csproj" + ) + + "not resolve WriteLine call" in { + cpg.call.nameExact("WriteLine").methodFullName.l shouldBe List( + ".WriteLine:" + ) + } + } + + "accompanied by a NET.Sdk csproj with ImplicitUsings enabled but excluding `System`" should { + val cpg = code(""" + |Console.WriteLine("Foo"); + |""".stripMargin) + .moreCode( + """ + | + | + | Exe + | enable + | + | + | + | + | + |""".stripMargin, + fileName = "App.csproj" + ) + + "not resolve WriteLine call" in { + cpg.call.nameExact("WriteLine").methodFullName.l shouldBe List( + ".WriteLine:" + ) + } + } + + "accompanied by a NET.Sdk csproj with ImplicitUsings disabled but including `System`" should { + val cpg = code(""" + |Console.WriteLine("Foo"); + |""".stripMargin) + .moreCode( + """ + | + | + | Exe + | false + | + | + | + | + | + |""".stripMargin, + fileName = "App.csproj" + ) + + "resolve WriteLine call" in { + cpg.call.nameExact("WriteLine").methodFullName.l shouldBe List( + "System.Console.WriteLine:System.Void(System.String)" + ) + } + } + + "accompanied by a NET.Sdk csproj with ImplicitUsings enabled but including and excluding `System`" should { + val cpg = code(""" + |Console.WriteLine("Foo"); + |""".stripMargin) + .moreCode( + """ + | + | + | Exe + | true + | + | + | + | + | + | + |""".stripMargin, + fileName = "App.csproj" + ) + + "not resolve WriteLine call" in { + cpg.call.nameExact("WriteLine").methodFullName.l shouldBe List( + ".WriteLine:" + ) + } + } + + "accompanied by a NET.Sdk csproj with ImplicitUsings enabled but excluding and including `System`" should { + val cpg = code(""" + |Console.WriteLine("Foo"); + |""".stripMargin) + .moreCode( + """ + | + | + | Exe + | true + | + | + | + | + | + | + |""".stripMargin, + fileName = "App.csproj" + ) + + "resolve WriteLine call" in { + cpg.call.nameExact("WriteLine").methodFullName.l shouldBe List( + "System.Console.WriteLine:System.Void(System.String)" + ) + } + } + } + +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/InheritanceFullNameTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/InheritanceFullNameTests.scala index d8476ab3879d..c06c6ad72e07 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/InheritanceFullNameTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/InheritanceFullNameTests.scala @@ -62,13 +62,13 @@ class InheritanceFullNameTests extends CSharpCode2CpgFixture { inside(qux.astChildren.isMethod.l) { case bazz :: Nil => - bazz.fullName shouldBe "HelloWorld.Qux.bazz:void()" + bazz.fullName shouldBe "HelloWorld.Qux.bazz:System.Void()" qux.fullName shouldBe "HelloWorld.Qux" qux.inheritsFromTypeFullName shouldBe Seq("HelloWorld.Foo") inside(qux.astChildren.isMethod.l) { case bazz :: Nil => - bazz.fullName shouldBe "HelloWorld.Qux.bazz:void()" + bazz.fullName shouldBe "HelloWorld.Qux.bazz:System.Void()" case _ => fail("There is no method named `baz` under `Qux` interface,") } case _ => fail("There is no interface named `Qux`") diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/LambdaTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/LambdaTests.scala index 6b6709784b59..b303d03fbd90 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/LambdaTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/LambdaTests.scala @@ -19,7 +19,7 @@ class LambdaTests extends CSharpCode2CpgFixture { inside(cpg.method("Main").astChildren.collectAll[Method].l) { case anon :: Nil => anon.name shouldBe "0" - anon.fullName shouldBe "HelloWorld.Program.Main:void(System.String[]).0" + anon.fullName shouldBe "HelloWorld.Program.Main.0:" inside(anon.parameter.l) { case x :: Nil => @@ -37,7 +37,7 @@ class LambdaTests extends CSharpCode2CpgFixture { inside(cpg.method("Main").astChildren.collectAll[TypeDecl].l) { case anon :: Nil => anon.name shouldBe "0" - anon.fullName shouldBe "HelloWorld.Program.Main:void(System.String[]).0" + anon.fullName shouldBe "HelloWorld.Program.Main.0:" case xs => fail(s"Expected a single anonymous type declaration, got [${xs.code.mkString(",")}]") } } @@ -48,7 +48,7 @@ class LambdaTests extends CSharpCode2CpgFixture { numbers.name shouldBe "numbers" numbers.typeFullName shouldBe s"${DotNetTypeMap(BuiltinTypes.Int)}[]" - closure.methodFullName shouldBe "HelloWorld.Program.Main:void(System.String[]).0" + closure.methodFullName shouldBe "HelloWorld.Program.Main.0:" closure.referencedMethod.name shouldBe "0" case xs => fail(s"Expected two `Select` call argument, got [${xs.code.mkString(",")}]") } @@ -69,7 +69,7 @@ class LambdaTests extends CSharpCode2CpgFixture { inside(cpg.method("Main").astChildren.collectAll[Method].l) { case anon :: Nil => anon.name shouldBe "0" - anon.fullName shouldBe "HelloWorld.Program.Main:void(System.String[]).0" + anon.fullName shouldBe "HelloWorld.Program.Main.0:" inside(anon.parameter.l) { case x :: y :: Nil => @@ -91,7 +91,7 @@ class LambdaTests extends CSharpCode2CpgFixture { inside(cpg.method("Main").astChildren.collectAll[TypeDecl].l) { case anon :: Nil => anon.name shouldBe "0" - anon.fullName shouldBe "HelloWorld.Program.Main:void(System.String[]).0" + anon.fullName shouldBe "HelloWorld.Program.Main.0:" case xs => fail(s"Expected a single anonymous type declaration, got [${xs.code.mkString(",")}]") } } @@ -102,7 +102,7 @@ class LambdaTests extends CSharpCode2CpgFixture { numbers.name shouldBe "numbers" numbers.typeFullName shouldBe s"${DotNetTypeMap(BuiltinTypes.Int)}[]" - closure.methodFullName shouldBe "HelloWorld.Program.Main:void(System.String[]).0" + closure.methodFullName shouldBe "HelloWorld.Program.Main.0:" closure.referencedMethod.name shouldBe "0" case xs => fail(s"Expected two `Select` call argument, got [${xs.code.mkString(",")}]") } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/LoopsTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/LoopsTests.scala index eecace4df9f1..d932655a6771 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/LoopsTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/LoopsTests.scala @@ -1,10 +1,11 @@ package io.joern.csharpsrc2cpg.querying.ast import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Local} import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Operators} import io.shiftleft.semanticcpg.language.* -class LoopsTests extends CSharpCode2CpgFixture { +class LoopsTests extends CSharpCode2CpgFixture(withDataFlow = true) { "AST Creation for loops" should { "be correct for foreach statement" in { val cpg = code(basicBoilerplate(""" @@ -19,22 +20,28 @@ class LoopsTests extends CSharpCode2CpgFixture { case forEachNode :: Nil => forEachNode.controlStructureType shouldBe ControlStructureTypes.FOR - inside(forEachNode.astChildren.isIdentifier.l) { - case iteratorNode :: iterableNode :: Nil => - iteratorNode.code shouldBe "element" - iteratorNode.typeFullName shouldBe "System.Int32" + inside(forEachNode.astChildren.l) { + case (idxLocal: Local) :: (elementLocal: Local) :: (initAssign: Call) :: (cond: Call) :: (update: Call) :: (forBlock: Block) :: Nil => + idxLocal.name shouldBe "_idx_" + idxLocal.typeFullName shouldBe "System.Int32" - iterableNode.code shouldBe "fibNumbers" - // TODO: List will be fully qualified once the System types are known - iterableNode.typeFullName shouldBe "List" - case _ => fail("No node for iterable found in `foreach` statement") - } + elementLocal.name shouldBe "element" + elementLocal.typeFullName shouldBe "System.Int32" + + initAssign.code shouldBe "_idx_ = 0" + initAssign.name shouldBe Operators.assignment + initAssign.methodFullName shouldBe Operators.assignment + + cond.code shouldBe "_idx_ < fibNumbers.Count" + cond.name shouldBe Operators.lessThan + cond.methodFullName shouldBe Operators.lessThan + + update.code shouldBe "element = fibNumbers[_idx_++]" + update.name shouldBe Operators.assignment + update.methodFullName shouldBe Operators.assignment - inside(forEachNode.astChildren.isBlock.l) { - case blockNode :: Nil => val List(writeCall) = cpg.call.nameExact("Write").l - writeCall.astParent shouldBe blockNode - case _ => fail("Correct blockNode as child not found for `foreach` statement") + writeCall.astParent shouldBe forBlock } case _ => fail("No control structure node found for `foreach`.") diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala index e6bee7b83a39..449465e0feef 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala @@ -6,7 +6,12 @@ import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier} import io.shiftleft.semanticcpg.language.* class MemberAccessTests extends CSharpCode2CpgFixture { - "conditional member access expressions" should { + + // TODO: This test-case relies on the usage of getters, that are currently being + // reworked to be METHODs instead of MEMBERs. In particular, `bar?.Qux` should + // resemble `bar.get_Qux()`. We need to adapt astForMemberBindingExpression + // to accommodate this. + "conditional property access expressions" ignore { val cpg = code(""" |namespace Foo { | public class Baz { @@ -35,6 +40,85 @@ class MemberAccessTests extends CSharpCode2CpgFixture { } } + "conditional member access expressions" should { + val cpg = code(""" + |namespace Foo { + | public class Baz { + | public int Qux; + | } + | public class Bar { + | public static void Main() { + | var baz = new Baz(); + | var a = baz?.Qux; + | } + | } + |} + |""".stripMargin) + + "have correct types both on the LHS and RHS" in { + inside(cpg.assignment.l.sortBy(_.lineNumber).drop(1)) { + case a :: Nil => + inside(a.argument.l) { + case (lhs: Identifier) :: (rhs: Call) :: Nil => + lhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int) + rhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int) + case _ => fail("Expected 2 arguments under the assignment call.") + } + case _ => fail("Expected 1 assignment call.") + } + } + } + + "chained field access expression referencing members of a sibling class" should { + val cpg = code(""" + |namespace Foo; + |class Bar + |{ + | Bar Field1; + | Bar Field2; + |} + |class Baz + |{ + | static void DoStuff() + | { + | var x = new Bar(); + | var y = x.Field1.Field2; + | } + |} + | + |""".stripMargin) + + "have correct typeDecls" in { + cpg.typeDecl.nameExact("Bar").size shouldBe 1 + } + + "have correct properties for the innermost member" in { + cpg.typeDecl.nameExact("Bar").member.nameExact("Field1").typeFullName.l shouldBe List("Foo.Bar") + } + + "have correct properties for the outermost member" in { + cpg.typeDecl.nameExact("Bar").member.nameExact("Field2").typeFullName.l shouldBe List("Foo.Bar") + } + + "have correct properties for the outermost field access" in { + inside(cpg.fieldAccess.where(_.fieldIdentifier.canonicalNameExact("Field2")).l) { + case field2 :: Nil => + field2.typeFullName shouldBe "Foo.Bar" + field2.referencedMember.l shouldBe cpg.typeDecl.nameExact("Bar").member.nameExact("Field2").l + case _ => fail("Expected single field access to `Field2`") + } + } + + "have correct properties for the innermost field access" in { + inside(cpg.fieldAccess.where(_.fieldIdentifier.canonicalNameExact("Field1")).l) { + case field1 :: Nil => + field1.typeFullName shouldBe "Foo.Bar" + field1.referencedMember.l shouldBe cpg.typeDecl.nameExact("Bar").member.nameExact("Field1").l + case _ => fail("Expected single field access to `Field1`") + } + } + + } "conditional method access expressions" should { val cpg = code(""" |namespace Foo { @@ -216,7 +300,11 @@ class MemberAccessTests extends CSharpCode2CpgFixture { } } - "conditional method access expression for chained fields" should { + // TODO: ConditionalAccessExpressions need some work to deal with nested chains. + // This particular test-case relies on the usage of getters, that are currently being + // reworked to be METHODs instead of MEMBERs. + // Revisit this test-case once getters are finished. + "conditional property access expression for chained fields" ignore { val cpg = code(""" |namespace Foo { | public class Baz { @@ -245,6 +333,35 @@ class MemberAccessTests extends CSharpCode2CpgFixture { } } + "conditional method access expression for chained fields" should { + val cpg = code(""" + |namespace Foo { + | public class Baz { + | public Baz Qux; + | } + | public class Bar { + | public static void Main() { + | var baz = new Baz(); + | var b = baz?.Qux?.Qux; + | } + | } + |} + |""".stripMargin) + + "have correct types and attributes both on the LHS and RHS" in { + inside(cpg.assignment.l.sortBy(_.lineNumber).drop(1).l) { + case a :: Nil => + inside(a.argument.l) { + case (lhs: Identifier) :: (rhs: Call) :: Nil => + lhs.typeFullName shouldBe "Foo.Baz" + rhs.typeFullName shouldBe "Foo.Baz" + case _ => fail("Expected 2 arguments under the assignment call") + } + case _ => fail("Expected 1 assignment call.") + } + } + } + "combination of method access expression for chained fields" should { val cpg = code(""" |namespace Foo { diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberTests.scala index 49ef9e23ed32..f67ec9219fbe 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberTests.scala @@ -8,73 +8,220 @@ import io.shiftleft.semanticcpg.language.* class MemberTests extends CSharpCode2CpgFixture { - "a basic class declaration" should { - val cpg = code( - """public class Car + "class with static and non-static members" should { + val cpg = code(""" + |class Car |{ - | string color; // field - | static int maxSpeed = 200; // field - | public void fullThrottle() // method - | { - | Console.WriteLine("The car is going as fast as it can!"); - | } - |} - |""".stripMargin, - "Car.cs" - ) + | string color; + | static int maxSpeed = 200; + |}""".stripMargin) + + "have the non-static member correctly set" in { + inside(cpg.member.nameExact("color").l) { + case color :: Nil => + color.typeFullName shouldBe "System.String" + color.code shouldBe "string color" + color.modifier.modifierType.l shouldBe List(ModifierTypes.INTERNAL) + case xs => + fail(s"Expected single `color` member, but got $xs") + } + } - "generate members for fields" in { - val x = cpg.typeDecl.nameExact("Car").head + "have the static member correctly set`" in { + inside(cpg.member.nameExact("maxSpeed").l) { + case maxSpeed :: Nil => + maxSpeed.typeFullName shouldBe "System.Int32" + maxSpeed.code shouldBe "int maxSpeed = 200" + maxSpeed.modifier.modifierType.toSet shouldBe Set(ModifierTypes.INTERNAL, ModifierTypes.STATIC) + case xs => + fail(s"Expected single `maxSpeed` member, but got $xs") + } + } + } + + "class with initialized static member" should { + val cpg = code(""" + |class Car + |{ + | static int nonInitMaxSpeed = 200; + |} + |""".stripMargin) - val color = x.member.nameExact("color").head - color.typeFullName shouldBe "System.String" - color.code shouldBe "string color" - color.modifier.modifierType.l shouldBe ModifierTypes.INTERNAL :: Nil + "have the static member correctly set" in { + inside(cpg.member.nameExact("nonInitMaxSpeed").l) { + case nonInitMaxSpeed :: Nil => + nonInitMaxSpeed.typeFullName shouldBe "System.Int32" + nonInitMaxSpeed.code shouldBe "int nonInitMaxSpeed = 200" + nonInitMaxSpeed.modifier.modifierType.l shouldBe List(ModifierTypes.INTERNAL, ModifierTypes.STATIC) + case xs => + fail(s"Expected single `nonInitMaxSpeed` member, but got $xs") + } + } - val maxSpeed = x.member.nameExact("maxSpeed").head - maxSpeed.typeFullName shouldBe "System.Int32" - maxSpeed.code shouldBe "int maxSpeed = 200" - maxSpeed.modifier.modifierType.l shouldBe ModifierTypes.INTERNAL :: ModifierTypes.STATIC :: Nil + "have a static constructor" in { + inside(cpg.typeDecl.nameExact("Car").method.nameExact(Defines.StaticInitMethodName).l) { + case cctor :: Nil => + cctor.fullName shouldBe s"Car.${Defines.StaticInitMethodName}:System.Void()" + cctor.modifier.modifierType.toSet shouldBe Set( + ModifierTypes.STATIC, + ModifierTypes.CONSTRUCTOR, + ModifierTypes.INTERNAL + ) + cctor.methodReturn.typeFullName shouldBe "System.Void" + case xs => + fail(s"Expected single static constructor, but got $xs") + } } + "have the static member initialization inside the static constructor" in { + inside(cpg.method.fullNameExact(s"Car.${Defines.StaticInitMethodName}:System.Void()").body.assignment.l) { + case assignment :: Nil => + assignment.target.code shouldBe "nonInitMaxSpeed" + assignment.source.code shouldBe "200" + case xs => + fail(s"Expected single assignment inside the static constructor, but got $xs") + } + } } - "a basic class declaration with a static constructor" should { - val cpg = code( - """public class Car + "class with initialized member" should { + val cpg = code(""" + |class Car |{ - | string color; // field - | static int maxSpeed = 200; // field - | static int nonInitMaxSpeed; // field - | - | public void fullThrottle() // method - | { - | Console.WriteLine("The car is going as fast as it can!"); - | } - | - | static Car() { // static constructor - | nonInitMaxSpeed = 2000; - | } - | + | string color = "red"; |} - |""".stripMargin, - "Car.cs" - ) + |""".stripMargin) - "generate one static constructor" in { + "have the member correctly set" in { + inside(cpg.member.nameExact("color").l) { + case color :: Nil => + color.typeFullName shouldBe "System.String" + color.code shouldBe "string color = \"red\"" + color.modifier.modifierType.l shouldBe List(ModifierTypes.INTERNAL) + case xs => + fail(s"Expected single `color` member, but got $xs") + } + } + + "have a constructor" in { + inside(cpg.typeDecl.nameExact("Car").method.nameExact(Defines.ConstructorMethodName).l) { + case ctor :: Nil => + ctor.fullName shouldBe s"Car.${Defines.ConstructorMethodName}:System.Void()" + ctor.modifier.modifierType.toSet shouldBe Set(ModifierTypes.INTERNAL, ModifierTypes.CONSTRUCTOR) + ctor.methodReturn.typeFullName shouldBe "System.Void" + case xs => + fail(s"Expected single constructor, but got $xs") + } + } + + "have the member initialization inside the constructor" in { + inside(cpg.method.fullNameExact(s"Car.${Defines.ConstructorMethodName}:System.Void()").body.assignment.l) { + case assignment :: Nil => + assignment.target.code shouldBe "color" + assignment.source.code shouldBe "\"red\"" + case xs => + fail(s"Expected single assignment inside the constructor, but got $xs") + } + } + } + + "class with static constructor" should { + val cpg = code(""" + |class Car + |{ + | static Car() + | { + | } + |}""".stripMargin) + "have a static constructor correctly set" in { inside(cpg.typeDecl.nameExact("Car").method.nameExact(Defines.StaticInitMethodName).l) { - case m :: Nil => - m.fullName shouldBe s"Car.${Defines.StaticInitMethodName}:void()" - m.modifier.modifierType.l shouldBe ModifierTypes.STATIC :: ModifierTypes.CONSTRUCTOR :: Nil - m.methodReturn.typeFullName shouldBe "void" + case cctor :: Nil => + cctor.fullName shouldBe s"Car.${Defines.StaticInitMethodName}:System.Void()" + cctor.modifier.modifierType.toSet shouldBe Set(ModifierTypes.STATIC, ModifierTypes.CONSTRUCTOR) + cctor.methodReturn.typeFullName shouldBe "System.Void" + case xs => + fail(s"Expected single static constructor, but got $xs") + } + } + } - inside(m.assignment.l) { - case maxSpeed :: nonInitMaxSpeed :: Nil => - maxSpeed.code shouldBe "maxSpeed = 200" - nonInitMaxSpeed.code shouldBe "nonInitMaxSpeed = 2000" - case _ => fail("Exactly 2 assignments expected") - } - case _ => fail("`Car` has no static initializer method") + "class with static constructor and initialized static member" should { + val cpg = code(""" + |class Car + |{ + | static int maxSpeed = 200; + | static Car() + | { + | } + |}""".stripMargin) + "have static member initialization inside static constructor" in { + inside(cpg.typeDecl.nameExact("Car").method.nameExact(Defines.StaticInitMethodName).body.assignment.l) { + case assignment :: Nil => + assignment.code shouldBe "maxSpeed = 200" + assignment.source.code shouldBe "200" + // TODO: target is currently an identifier. Should it be `Car.maxSpeed` instead? + assignment.target.code shouldBe "maxSpeed" + case xs => + fail(s"Expected single assignment inside static constructor, but got $xs") + } + } + } + + "class with static constructor initializing a member, plus an initialized static member" should { + val cpg = code(""" + |class Car + |{ + | static int maxSpeed = 200; + | static int nonInitMaxSpeed; + | static Car() + | { + | nonInitMaxSpeed = 300; + | } + |} + |""".stripMargin) + "have static constructor with two assignments for initializing the members" in { + inside(cpg.typeDecl.nameExact("Car").method.nameExact(Defines.StaticInitMethodName).assignment.sortBy(_.code).l) { + case maxSpeedAssignment :: nonInitMaxSpeedAssignment :: Nil => + maxSpeedAssignment.code shouldBe "maxSpeed = 200" + nonInitMaxSpeedAssignment.code shouldBe "nonInitMaxSpeed = 300" + + // TODO: They should have the same representation + maxSpeedAssignment.target.code shouldBe "maxSpeed" + nonInitMaxSpeedAssignment.target.code shouldBe "Car.nonInitMaxSpeed" + case xs => + fail(s"Expected two assignments, but got $xs") + } + } + } + + "class with initialized member and default constructor" should { + val cpg = code(""" + |class Car + |{ + | string color = "red"; + | Car() + | { + | } + |}""".stripMargin) + "have a constructor" in { + inside(cpg.typeDecl.nameExact("Car").method.nameExact(Defines.ConstructorMethodName).l) { + case ctor :: Nil => + ctor.fullName shouldBe s"Car.${Defines.ConstructorMethodName}:System.Void()" + ctor.modifier.modifierType.toSet shouldBe Set(ModifierTypes.CONSTRUCTOR) + ctor.methodReturn.typeFullName shouldBe "System.Void" + case xs => + fail(s"Expected single constructor, but got $xs") + } + } + + "have the member initialization inside the constructor" in { + inside(cpg.method.fullNameExact(s"Car.${Defines.ConstructorMethodName}:System.Void()").body.assignment.l) { + case assignment :: Nil => + // TODO: test LHS: shouldn't it resemble `this.color`? + assignment.target.code shouldBe "color" + assignment.source.code shouldBe "\"red\"" + case xs => + fail(s"Expected single assignment inside the constructor, but got $xs") } } } @@ -105,9 +252,9 @@ class MemberTests extends CSharpCode2CpgFixture { "generate one constructor" in { inside(cpg.typeDecl.nameExact("Car").method.nameExact(Defines.ConstructorMethodName).l) { case m :: Nil => - m.fullName shouldBe s"Car.${Defines.ConstructorMethodName}:void()" + m.fullName shouldBe s"Car.${Defines.ConstructorMethodName}:System.Void()" m.modifier.modifierType.l shouldBe ModifierTypes.PUBLIC :: ModifierTypes.CONSTRUCTOR :: Nil - m.methodReturn.typeFullName shouldBe "void" + m.methodReturn.typeFullName shouldBe "System.Void" inside(m.assignment.l) { case color :: initMaxSpeed :: Nil => @@ -146,9 +293,9 @@ class MemberTests extends CSharpCode2CpgFixture { "generate one constructor with necessary parameters" in { inside(cpg.typeDecl.nameExact("Car").method.nameExact(Defines.ConstructorMethodName).l) { case m :: Nil => - m.fullName shouldBe s"Car.${Defines.ConstructorMethodName}:void(System.Int32)" + m.fullName shouldBe s"Car.${Defines.ConstructorMethodName}:System.Void(System.Int32)" m.modifier.modifierType.l shouldBe ModifierTypes.PUBLIC :: ModifierTypes.CONSTRUCTOR :: Nil - m.methodReturn.typeFullName shouldBe "void" + m.methodReturn.typeFullName shouldBe "System.Void" inside(m.assignment.l) { case color :: initMaxSpeed :: Nil => @@ -181,7 +328,7 @@ class MemberTests extends CSharpCode2CpgFixture { | int b; | | static Car() { // static constructor - | this.nonInitMaxSpeed = 2000; + | nonInitMaxSpeed = 2000; | } | | public Car() { @@ -200,7 +347,7 @@ class MemberTests extends CSharpCode2CpgFixture { inside(m.body.astChildren.isCall.l) { case staticImplicit :: staticExplicit :: Nil => staticExplicit.methodFullName shouldBe Operators.assignment - staticExplicit.code shouldBe "this.nonInitMaxSpeed = 2000" + staticExplicit.code shouldBe "nonInitMaxSpeed = 2000" inside(staticExplicit.argument.fieldAccess.l) { case fieldAccess :: Nil => @@ -297,7 +444,9 @@ class MemberTests extends CSharpCode2CpgFixture { } } - "a basic class declaration with a PropertyDeclaration member" should { + // TODO: Getters/Setters are currently being lowered into get_/set_ methods. + // Adapt this unit-test once that is finished. + "a basic class declaration with a PropertyDeclaration member" ignore { val cpg = code(""" |public class Foo { | public int Bar {get; set;} @@ -319,6 +468,28 @@ class MemberTests extends CSharpCode2CpgFixture { } } + "a basic class declaration with a FieldDeclaration member" should { + val cpg = code(""" + |public class Foo { + | public int Bar; + |} + |""".stripMargin) + + "create a member for Bar with appropriate properties" in { + inside(cpg.typeDecl.nameExact("Foo").l) { + case fooClass :: Nil => + inside(fooClass.astChildren.isMember.nameExact("Bar").l) { + case bar :: Nil => + bar.code shouldBe "int Bar" + bar.typeFullName shouldBe "System.Int32" + bar.astParent shouldBe fooClass + case _ => fail("No member named Bar found inside Foo") + } + case _ => fail("No class named Foo found.") + } + } + } + "a member with a external type" should { val cpg = code(""" |using Microsoft.Extensions.Logging; diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MethodTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MethodTests.scala index 474b6a0251ea..577976d310b3 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MethodTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MethodTests.scala @@ -2,6 +2,7 @@ package io.joern.csharpsrc2cpg.querying.ast import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture import io.shiftleft.codepropertygraph.generated.ModifierTypes +import io.shiftleft.codepropertygraph.generated.nodes.Return import io.shiftleft.semanticcpg.language.* class MethodTests extends CSharpCode2CpgFixture { @@ -11,9 +12,9 @@ class MethodTests extends CSharpCode2CpgFixture { "generate a method node with type decl parent" in { val x = cpg.method.nameExact("Main").head - x.fullName should startWith("HelloWorld.Program.Main:void") - x.fullName shouldBe "HelloWorld.Program.Main:void(System.String[])" - x.signature shouldBe "void(System.String[])" + x.fullName should startWith("HelloWorld.Program.Main:System.Void") + x.fullName shouldBe "HelloWorld.Program.Main:System.Void(System.String[])" + x.signature shouldBe "System.Void(System.String[])" x.filename shouldBe "Program.cs" x.code shouldBe "static void Main(string[] args)" @@ -77,4 +78,93 @@ class MethodTests extends CSharpCode2CpgFixture { } } + "empty public abstract method" should { + val cpg = code(""" + |abstract class C + |{ + | public abstract void DoStuff(); + |} + |""".stripMargin) + + "have correct modifiers" in { + cpg.method.nameExact("DoStuff").modifier.modifierType.sorted.l shouldBe List( + ModifierTypes.ABSTRACT, + ModifierTypes.PUBLIC + ) + } + } + + "empty protected abstract method" should { + val cpg = code(""" + |abstract class C + |{ + | protected abstract void DoStuff(); + |}""".stripMargin) + + "have correct modifiers" in { + cpg.method.nameExact("DoStuff").modifier.modifierType.sorted.l shouldBe List( + ModifierTypes.ABSTRACT, + ModifierTypes.PROTECTED + ) + } + } + + "standalone method declaration inside a top-level method" should { + val cpg = code(""" + |int MyMain() + |{ + | int MySubMethod() {return 1;} + |} + |""".stripMargin) + + "have correct properties for the nested method" in { + inside(cpg.method.nameExact("MySubMethod").l) { + case sub :: Nil => + sub.fullName shouldBe "Test0_cs_Program.
$.MyMain.MySubMethod:System.Int32()" + sub.signature shouldBe "System.Int32()" + sub.modifier.modifierType.sorted.l shouldBe List(ModifierTypes.INTERNAL) + sub.methodReturn.typeFullName shouldBe "System.Int32" + sub.parentBlock.method.l shouldBe cpg.method.fullNameExact("Test0_cs_Program.
$.MyMain:System.Int32()").l + case xs => fail(s"Expected single MySubMethod METHOD, but got $xs") + } + } + + "have correct body for the nested method" in { + inside(cpg.method.nameExact("MySubMethod").block.astChildren.l) { + case (ret: Return) :: Nil => ret.code shouldBe "return 1;" + case xs => fail(s"Expected single RETURN node, but got $xs") + } + } + } + + "standalone method declaration inside a class method" should { + val cpg = code(""" + |class MyClass + |{ + | int MyMain() + | { + | int MySubMethod() {return 1;} + | } + |} + |""".stripMargin) + + "have correct properties for the nested method" in { + inside(cpg.method.nameExact("MySubMethod").l) { + case sub :: Nil => + sub.fullName shouldBe "MyClass.MyMain.MySubMethod:System.Int32()" + sub.signature shouldBe "System.Int32()" + sub.modifier.modifierType.sorted.l shouldBe List(ModifierTypes.INTERNAL) + sub.methodReturn.typeFullName shouldBe "System.Int32" + sub.parentBlock.method.l shouldBe cpg.method.fullNameExact("MyClass.MyMain:System.Int32()").l + case xs => fail(s"Expected single MySubMethod METHOD, but got $xs") + } + } + + "have correct body for the nested method" in { + inside(cpg.method.nameExact("MySubMethod").block.astChildren.l) { + case (ret: Return) :: Nil => ret.code shouldBe "return 1;" + case xs => fail(s"Expected single RETURN node, but got $xs") + } + } + } } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ObjectCreationTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ObjectCreationTests.scala new file mode 100644 index 000000000000..ffafea6e3911 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ObjectCreationTests.scala @@ -0,0 +1,48 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.joern.x2cpg.Defines +import io.shiftleft.codepropertygraph.generated.nodes.Identifier +import io.shiftleft.semanticcpg.language.* + +class ObjectCreationTests extends CSharpCode2CpgFixture { + + "assignment to an object creation for a known class" should { + val cpg = code(""" + |using System.Text; + |var x = new StringBuilder(); + |""".stripMargin) + + "have correct constructor call properties" in { + inside(cpg.call.nameExact(Defines.ConstructorMethodName).headOption) { + case Some(ctor) => + ctor.typeFullName shouldBe "System.Text.StringBuilder" + ctor.methodFullName shouldBe "System.Text.StringBuilder." + case None => fail(s"Expected a constructor call") + } + } + + "have correct typeFullName for the assigned variable" in { + cpg.assignment.target.isIdentifier.nameExact("x").typeFullName.l shouldBe List("System.Text.StringBuilder") + } + } + + "assignment to a fully-qualified object creation for a known class" should { + val cpg = code(""" + |var x = new System.Text.StringBuilder(); + |""".stripMargin) + + "have correct constructor call properties" in { + inside(cpg.call.nameExact(Defines.ConstructorMethodName).headOption) { + case Some(ctor) => + ctor.typeFullName shouldBe "System.Text.StringBuilder" + ctor.methodFullName shouldBe "System.Text.StringBuilder." + case None => fail(s"Expected a constructor CALL") + } + } + + "have correct typeFullName for the assigned variable" in { + cpg.assignment.target.isIdentifier.nameExact("x").typeFullName.l shouldBe List("System.Text.StringBuilder") + } + } +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/PropertyGetterTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/PropertyGetterTests.scala new file mode 100644 index 000000000000..ffe35373dd18 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/PropertyGetterTests.scala @@ -0,0 +1,306 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, ModifierTypes} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal} +import io.shiftleft.semanticcpg.language.* + +class PropertyGetterTests extends CSharpCode2CpgFixture { + + "`System.Console.Out` being assigned to a variable" should { + val cpg = code(""" + |using System; + |var x = System.Console.Out; + |""".stripMargin) + + "have variable correctly typed" in { + cpg.identifier.nameExact("x").typeFullName.l shouldBe List("System.IO.TextWriter") + } + + "have System.Console.Out correctly set" in { + inside(cpg.call.code("System.Console.Out").l) { + case consoleOut :: Nil => + consoleOut.name shouldBe "get_Out" + consoleOut.methodFullName shouldBe "System.Console.get_Out:System.IO.TextWriter()" + consoleOut.argument shouldBe empty + consoleOut.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + consoleOut.typeFullName shouldBe "System.IO.TextWriter" + case xs => fail(s"Expected single call for System.Console.Out, but got $xs") + } + } + } + + "`System.Console.Out.WriteLine` call" should { + val cpg = code(""" + |using System; + |using System.IO; + |System.Console.Out.WriteLine("X"); + |""".stripMargin) + + "have correct properties for WriteLine" in { + inside(cpg.call.nameExact("WriteLine").l) { + case writeLine :: Nil => + writeLine.code shouldBe "System.Console.Out.WriteLine(\"X\")" + writeLine.methodFullName shouldBe "System.IO.TextWriter.WriteLine:System.Void(System.String)" + writeLine.typeFullName shouldBe "System.Void" + case xs => fail(s"Expected single WriteLine call, but got $xs") + } + } + + "have correct arguments for WriteLine" in { + inside(cpg.call.nameExact("WriteLine").argument.sortBy(_.argumentIndex).l) { + case (receiver: Call) :: (literal: Literal) :: Nil => + receiver.argumentIndex shouldBe 0 + receiver.code shouldBe "System.Console.Out" + receiver.name shouldBe "get_Out" + receiver.typeFullName shouldBe "System.IO.TextWriter" + + literal.argumentIndex shouldBe 1 + literal.code shouldBe "\"X\"" + literal.typeFullName shouldBe "System.String" + case xs => fail(s"Expected two arguments for WriteLine, but got $xs") + } + } + + "have correct properties for System.Console.Out" in { + inside(cpg.call.code("System.Console.Out").l) { + case out :: Nil => + out.name shouldBe "get_Out" + out.typeFullName shouldBe "System.IO.TextWriter" + out.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + out.methodFullName shouldBe "System.Console.get_Out:System.IO.TextWriter()" + case xs => + fail(s"Expected single call for System.Console.Out, but got $xs") + } + } + + "have correct arguments for System.Console.Out" in { + cpg.call.code("System.Console.out").argument shouldBe empty + } + } + + // FIXME: Fails because without System.IO in scope, System.Console.Out can't be found. + "`System.Console.Out.WriteLine` call without importing `System.IO`" ignore { + val cpg = code(""" + |System.Console.Out.WriteLine("X"); + |""".stripMargin) + + "have correct properties for WriteLine" in { + inside(cpg.call.nameExact("WriteLine").l) { + case writeLine :: Nil => + writeLine.code shouldBe "System.Console.Out.WriteLine(\"X\")" + writeLine.methodFullName shouldBe "System.IO.TextWriter.WriteLine:System.Void(System.String)" + writeLine.typeFullName shouldBe "System.Void" + case xs => fail(s"Expected single WriteLine call, but got $xs") + } + } + + "have correct arguments for WriteLine" in { + inside(cpg.call.nameExact("WriteLine").argument.sortBy(_.argumentIndex).l) { + case (receiver: Call) :: (literal: Literal) :: Nil => + receiver.argumentIndex shouldBe 0 + receiver.code shouldBe "System.Console.Out" + receiver.name shouldBe "get_Out" + receiver.typeFullName shouldBe "System.IO.TextWriter" + + literal.argumentIndex shouldBe 1 + literal.code shouldBe "\"X\"" + literal.typeFullName shouldBe "System.String" + case xs => fail(s"Expected two arguments for WriteLine, but got $xs") + } + } + + "have correct properties for System.Console.Out" in { + inside(cpg.call.code("System.Console.Out").l) { + case out :: Nil => + out.name shouldBe "get_Out" + out.typeFullName shouldBe "System.IO.TextWriter" + out.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + out.methodFullName shouldBe "System.Console.get_Out:System.IO.TextWriter()" + case xs => + fail(s"Expected single call for System.Console.Out, but got $xs") + } + } + + "have correct arguments for System.Console.Out" in { + cpg.call.code("System.Console.out").argument shouldBe empty + } + } + + "`ConsoleKeyInfo.KeyChar` being assigned to a variable" should { + val cpg = code(""" + |using System; + |var x = new ConsoleKeyInfo(); + |var y = x.KeyChar; + |""".stripMargin) + + "have variable correctly typed" in { + cpg.assignment.target.isIdentifier.nameExact("x").typeFullName.l shouldBe List("System.ConsoleKeyInfo") + cpg.assignment.target.isIdentifier.nameExact("y").typeFullName.l shouldBe List("System.Char") + } + + "have correct properties for KeyChar call" in { + inside(cpg.call.code("x.KeyChar").l) { + case keyChar :: Nil => + keyChar.name shouldBe "get_KeyChar" + keyChar.methodFullName shouldBe "System.ConsoleKeyInfo.get_KeyChar:System.Char(System.ConsoleKeyInfo)" + keyChar.typeFullName shouldBe "System.Char" + keyChar.signature shouldBe "System.Char(System.ConsoleKeyInfo)" + case xs => fail(s"Expected single call to KeyChar, but got $xs") + } + } + + "have correct arguments for KeyChar call" in { + inside(cpg.call.code("x.KeyChar").argument.sortBy(_.argumentIndex).l) { + case (x: Identifier) :: Nil => + x.typeFullName shouldBe "System.ConsoleKeyInfo" + x.code shouldBe "x" + x.name shouldBe "x" + x.argumentIndex shouldBe 0 + case xs => fail(s"Expected single identifier argument to KeyChar, but got $xs") + } + } + } + + "uninitialized get-only property declaration" should { + val cpg = code(""" + |class C + |{ + | public int MyProperty { get; } + |} + |""".stripMargin) + + "be lowered into a get_* method" in { + inside(cpg.method.nameExact("get_MyProperty").l) { + case method :: Nil => + method.fullName shouldBe "C.get_MyProperty:System.Int32(C)" + method.signature shouldBe "System.Int32(C)" + case xs => fail(s"Expected single get_MyProperty method, but got $xs") + } + } + + "have correct modifiers" in { + cpg.method.nameExact("get_MyProperty").modifier.modifierType.sorted.l shouldBe List(ModifierTypes.PUBLIC) + } + + "have correct parameters" in { + inside(cpg.method.nameExact("get_MyProperty").parameter.l) { + case thisParam :: Nil => + thisParam.typeFullName shouldBe "C" + thisParam.name shouldBe "this" + case xs => fail(s"Expected this parameter for get_MyProperty, but got $xs") + } + } + + "have empty body" in { + cpg.method.nameExact("get_MyProperty").body.astChildren shouldBe empty + } + } + + "assignment whose RHS is a get-only property declared in the source-code" should { + val cpg = code(""" + |class C { public int MyProperty {get;} } + |class M + |{ + | void Run() + | { + | var c = new C(); + | var x = c.MyProperty; + | } + |} + |""".stripMargin) + + "have a get_* method call on the RHS" in { + inside(cpg.assignment.where(_.target.isIdentifier.nameExact("x")).source.l) { + case (rhs: Call) :: Nil => + rhs.code shouldBe "c.MyProperty" + rhs.name shouldBe "get_MyProperty" + rhs.methodFullName shouldBe "C.get_MyProperty:System.Int32(C)" + rhs.typeFullName shouldBe "System.Int32" + rhs.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + case xs => fail(s"Expected single RHS call for the assignment of x, but got $xs") + } + } + + "have correct arguments to the get_* call" in { + inside(cpg.call.codeExact("c.MyProperty").argument.sortBy(_.argumentIndex).l) { + case (baseArg: Identifier) :: Nil => + baseArg.argumentIndex shouldBe 0 + baseArg.code shouldBe "c" + baseArg.typeFullName shouldBe "C" + case xs => fail(s"Expected single identifier argument to c.MyProperty, but got $xs") + } + } + + "have correct typeFullName for the assignment" in { + cpg.assignment.where(_.target.isIdentifier.nameExact("x")).typeFullName.l shouldBe List("System.Int32") + } + } + + "assignment whose RHS is a static get-only property declared in the source-code" should { + val cpg = code(""" + |class C { public static int MyProperty {get;} } + |class M + |{ + | void Run() + | { + | var c = new C(); + | var x = c.MyProperty; + | } + |} + |""".stripMargin) + + "have a get_* method call on the RHS" in { + inside(cpg.assignment.where(_.target.isIdentifier.nameExact("x")).source.l) { + case (rhs: Call) :: Nil => + rhs.code shouldBe "c.MyProperty" + rhs.name shouldBe "get_MyProperty" + rhs.methodFullName shouldBe "C.get_MyProperty:System.Int32()" + rhs.typeFullName shouldBe "System.Int32" + rhs.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + case xs => fail(s"Expected single RHS call for the assignment of x, but got $xs") + } + } + + "have correct arguments to the get_* call" in { + cpg.call.codeExact("c.MyProperty").argument shouldBe empty + } + + "have correct typeFullName for the assignment" in { + cpg.assignment.where(_.target.isIdentifier.nameExact("x")).typeFullName.l shouldBe List("System.Int32") + } + } + + "uninitialized static get-only property declaration" should { + val cpg = code(""" + |public class C + |{ + | public static string MyProperty { get; } + |} + |""".stripMargin) + + "be lowered into a get_* method" in { + inside(cpg.method.nameExact("get_MyProperty").l) { + case method :: Nil => + method.fullName shouldBe "C.get_MyProperty:System.String()" + method.signature shouldBe "System.String()" + case xs => fail(s"Expected single get_MyProperty method, but got $xs") + } + } + + "have correct modifiers" in { + cpg.method.nameExact("get_MyProperty").modifier.modifierType.sorted.l shouldBe List( + ModifierTypes.PUBLIC, + ModifierTypes.STATIC + ) + } + + "have no parameters" in { + cpg.method.nameExact("get_MyProperty").parameter shouldBe empty + } + + "have empty body" in { + cpg.method.nameExact("get_MyProperty").body.astChildren shouldBe empty + } + } +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/PropertySetterTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/PropertySetterTests.scala new file mode 100644 index 000000000000..c7b3fa2de94f --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/PropertySetterTests.scala @@ -0,0 +1,508 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, ModifierTypes, Operators} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal} +import io.shiftleft.semanticcpg.language.* + +class PropertySetterTests extends CSharpCode2CpgFixture { + + "uninitialized set-only property declaration" should { + val cpg = code(""" + |using System; + |class C + |{ + | public int MyProperty { set { Console.WriteLine(value); } } + |} + |""".stripMargin) + + "be lowered into a set_* method" in { + inside(cpg.method.nameExact("set_MyProperty").l) { + case method :: Nil => + method.fullName shouldBe "C.set_MyProperty:void(C,System.Int32)" + method.signature shouldBe "void(C,System.Int32)" + case xs => fail(s"Expected single set_MyProperty method, but got $xs") + } + } + + "have correct modifiers" in { + cpg.method.nameExact("set_MyProperty").modifier.modifierType.sorted.l shouldBe List(ModifierTypes.PUBLIC) + } + + "have correct parameters" in { + inside(cpg.method.nameExact("set_MyProperty").parameter.sortBy(_.index).l) { + case thisArg :: valueArg :: Nil => + thisArg.index shouldBe 0 + thisArg.name shouldBe "this" + thisArg.typeFullName shouldBe "C" + + valueArg.index shouldBe 1 + valueArg.name shouldBe "value" + valueArg.typeFullName shouldBe "System.Int32" + case xs => fail(s"Expected two arguments to set_MyProperty, but got $xs") + } + } + + "have correct body" in { + inside(cpg.method.nameExact("set_MyProperty").body.flatMap(_.astChildren).l) { + case (writeLine: Call) :: Nil => + writeLine.code shouldBe "Console.WriteLine(value)" + writeLine.methodFullName shouldBe "System.Console.WriteLine:System.Void(System.Boolean)" + case xs => fail(s"Expected single node inside set_MyProperty's body, but got $xs") + } + } + } + + "uninitialized static set-only property declaration" should { + val cpg = code(""" + |class C + |{ + | public static int MyProperty { set { } } + |} + |""".stripMargin) + + "be lowered into a set_* method" in { + inside(cpg.method.nameExact("set_MyProperty").l) { + case method :: Nil => + method.fullName shouldBe "C.set_MyProperty:void(System.Int32)" + method.signature shouldBe "void(System.Int32)" + case xs => fail(s"Expected single set_MyProperty method, but got $xs") + } + } + + "have correct modifiers" in { + cpg.method.nameExact("set_MyProperty").modifier.modifierType.sorted.l shouldBe List( + ModifierTypes.PUBLIC, + ModifierTypes.STATIC + ) + } + + "have correct parameters" in { + inside(cpg.method.nameExact("set_MyProperty").parameter.sortBy(_.index).l) { + case valueArg :: Nil => + valueArg.index shouldBe 1 + valueArg.name shouldBe "value" + valueArg.typeFullName shouldBe "System.Int32" + case xs => fail(s"Expected two arguments to set_MyProperty, but got $xs") + } + } + + "have correct body" in { + cpg.method.nameExact("set_MyProperty").body.flatMap(_.astChildren) shouldBe empty + } + } + + "setting a previously declared {set{}} property via `x.Property = y` where `x` is a local variable" should { + val cpg = code(""" + |class MyData + |{ + | public int MyProperty { set {} } + |} + |class Main + |{ + | public static void DoStuff() + | { + | var m = new MyData(); + | m.MyProperty = 3; // rendered as MyData.set_MyProperty(m, 3) + | } + |} + |""".stripMargin) + + "be translated to that property's set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").l) { + case setter :: Nil => + setter.code shouldBe "m.MyProperty = 3" + setter.methodFullName shouldBe "MyData.set_MyProperty:System.Void(System.Int32)" + setter.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + case xs => fail(s"Expected single call to set_MyProperty, but got $xs") + } + } + + "have correct arguments to the set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").argument.sortBy(_.argumentIndex).l) { + case (m: Identifier) :: (three: Literal) :: Nil => + m.typeFullName shouldBe "MyData" + m.code shouldBe "m" + m.argumentIndex shouldBe 0 + three.code shouldBe "3" + three.typeFullName shouldBe "System.Int32" + three.argumentIndex shouldBe 1 + case xs => fail(s"Unexpected arguments to set_MyProperty, got $xs") + } + } + } + + "setting a previously declared {get;set;} property via `x.Property = y` where `x` is method parameter" should { + val cpg = code(""" + |class MyData + |{ + | public int MyProperty { get; set; } + |} + |class Main + |{ + | public static void DoStuff(MyData m) + | { + | m.MyProperty = 3; // rendered as MyData.set_MyProperty(m, 3) + | } + |} + |""".stripMargin) + + "be translated to that property's set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").l) { + case setter :: Nil => + setter.code shouldBe "m.MyProperty = 3" + setter.methodFullName shouldBe "MyData.set_MyProperty:System.Void(System.Int32)" + setter.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + case xs => fail(s"Expected single call to set_MyProperty, but got $xs") + } + } + + "have correct arguments to the set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").argument.sortBy(_.argumentIndex).l) { + case (m: Identifier) :: (three: Literal) :: Nil => + m.typeFullName shouldBe "MyData" + m.code shouldBe "MyData m" + m.argumentIndex shouldBe 0 + three.code shouldBe "3" + three.typeFullName shouldBe "System.Int32" + three.argumentIndex shouldBe 1 + case xs => fail(s"Unexpected arguments to set_MyProperty, got $xs") + } + } + } + + "setting a previously declared {get;set;} property via `this.Property = y`" should { + val cpg = code(""" + |class MyData + |{ + | public int MyProperty { get; set; } + | public void DoStuff() + | { + | this.MyProperty = 3; // rendered as MyData.set_MyProperty(this, 3) + | } + |} + |""".stripMargin) + + "be translated to that property's set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").l) { + case setter :: Nil => + setter.code shouldBe "this.MyProperty = 3" + setter.methodFullName shouldBe "MyData.set_MyProperty:System.Void(System.Int32)" + setter.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + case xs => fail(s"Expected single call to set_MyProperty, but got $xs") + } + } + + "have correct arguments to the set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").argument.sortBy(_.argumentIndex).l) { + case (thisId: Identifier) :: (three: Literal) :: Nil => + thisId.typeFullName shouldBe "MyData" + thisId.code shouldBe "this" + thisId.argumentIndex shouldBe 0 + three.code shouldBe "3" + three.typeFullName shouldBe "System.Int32" + three.argumentIndex shouldBe 1 + case xs => fail(s"Unexpected arguments to set_MyProperty, got $xs") + } + } + } + + "setting a previously declared {get;set;} property via `Property = y`" should { + val cpg = code(""" + |class MyData + |{ + | public int MyProperty { get; set; } + | public void DoStuff() + | { + | MyProperty = 3; // rendered as MyData.set_MyProperty(this, 3) + | } + |} + |""".stripMargin) + + "be translated to that property's set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").l) { + case setter :: Nil => + setter.code shouldBe "MyProperty = 3" + setter.methodFullName shouldBe "MyData.set_MyProperty:System.Void(System.Int32)" + setter.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + case xs => fail(s"Expected single call to set_MyProperty, but got $xs") + } + } + + "have correct arguments to the set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").argument.sortBy(_.argumentIndex).l) { + case (thisId: Identifier) :: (three: Literal) :: Nil => + thisId.typeFullName shouldBe "MyData" + thisId.code shouldBe "this" + thisId.argumentIndex shouldBe 0 + three.code shouldBe "3" + three.typeFullName shouldBe "System.Int32" + three.argumentIndex shouldBe 1 + case xs => fail(s"Unexpected arguments to set_MyProperty, got $xs") + } + } + } + + "setting a previously declared {get;set;} property via `Property *= y" should { + val cpg = code(""" + |class MyData + |{ + | public int MyProperty { get; set; } + | public void DoStuff() + | { + | MyProperty *= 3; // rendered as set_MyProperty(get_MyProperty() * 3); + | } + |} + |""".stripMargin) + + "be translated to that property's set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").l) { + case setter :: Nil => + setter.code shouldBe "MyProperty *= 3" + setter.methodFullName shouldBe "MyData.set_MyProperty:System.Void(System.Int32)" + setter.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + case xs => fail(s"Expected single call to set_MyProperty, but got $xs") + } + } + + "have correct arguments to the set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").argument.sortBy(_.argumentIndex).l) { + case (thisId: Identifier) :: (timesCall: Call) :: Nil => + thisId.typeFullName shouldBe "MyData" + thisId.code shouldBe "this" + thisId.argumentIndex shouldBe 0 + + timesCall.argumentIndex shouldBe 1 + timesCall.code shouldBe "MyProperty *= 3" + timesCall.methodFullName shouldBe Operators.multiplication + timesCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + case xs => fail(s"Unexpected arguments to set_MyProperty, got $xs") + } + } + + "have correct arguments to the synthetic `*` call" in { + inside(cpg.call.nameExact("set_MyProperty").argument(1).isCall.argument.sortBy(_.argumentIndex).l) { + case (getter: Call) :: (three: Literal) :: Nil => + getter.argumentIndex shouldBe 1 + getter.code shouldBe "MyProperty *= 3" + getter.methodFullName shouldBe "MyData.get_MyProperty:System.Int32()" + getter.name shouldBe "get_MyProperty" + getter.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + + three.argumentIndex shouldBe 2 + three.code shouldBe "3" + three.typeFullName shouldBe "System.Int32" + case xs => fail(s"Expected two arguments for +, but got $xs") + } + } + + "have correct arguments to the synthetic getter call" in { + inside(cpg.call.nameExact("get_MyProperty").argument.sortBy(_.argumentIndex).l) { + case (receiver: Identifier) :: Nil => + receiver.argumentIndex shouldBe 0 + receiver.typeFullName shouldBe "MyData" + receiver.code shouldBe "this" + receiver.name shouldBe "this" + case xs => fail(s"Expected single argument to get_MyProperty, but got $xs") + } + } + } + + "setting a previously declared static {get;set;} property via `Property = y`" should { + val cpg = code(""" + |class MyData + |{ + | public static int MyProperty { get; set; } + | public void DoStuff() + | { + | MyProperty = 3; // rendered as MyData.set_MyProperty(3) + | } + |} + |""".stripMargin) + + "be translated to that property's set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").l) { + case setter :: Nil => + setter.code shouldBe "MyProperty = 3" + setter.methodFullName shouldBe "MyData.set_MyProperty:System.Void(System.Int32)" + setter.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + case xs => fail(s"Expected single call to set_MyProperty, but got $xs") + } + } + + "have correct arguments to the set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").argument.sortBy(_.argumentIndex).l) { + case (three: Literal) :: Nil => + three.code shouldBe "3" + three.typeFullName shouldBe "System.Int32" + three.argumentIndex shouldBe 1 + case xs => fail(s"Unexpected arguments to set_MyProperty, got $xs") + } + } + } + + "setting a previously declared static {set{}} property via `C.Property = y` where `C` is the class name" should { + val cpg = code(""" + |class MyData + |{ + | public static int MyProperty { set{} } + |} + |class Main + |{ + | public static void DoStuff() + | { + | MyData.MyProperty = 3; // rendered as MyData.set_MyProperty(3) + | } + |} + |""".stripMargin) + + "be translated to that property's set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").l) { + case setter :: Nil => + setter.code shouldBe "MyData.MyProperty = 3" + setter.methodFullName shouldBe "MyData.set_MyProperty:System.Void(System.Int32)" + setter.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + case xs => fail(s"Expected single call to set_MyProperty, but got $xs") + } + } + + "have correct arguments to the set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").argument.sortBy(_.argumentIndex).l) { + case (three: Literal) :: Nil => + three.code shouldBe "3" + three.typeFullName shouldBe "System.Int32" + three.argumentIndex shouldBe 1 + case xs => fail(s"Unexpected arguments to set_MyProperty, got $xs") + } + } + } + + "setting a previously declared {get;set;} property via `x.Property += y` where `x` is a local variable" should { + val cpg = code(""" + |class MyData + |{ + | public int MyProperty { get; set; } + |} + |class Main + |{ + | public static void DoStuff() + | { + | var m = new MyData(); + | m.MyProperty += 3; // rendered as MyData.set_MyProperty(m, MyData.get_MyProperty() + 3) + | } + |} + |""".stripMargin) + + "be translated to that property's set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").l) { + case setter :: Nil => + setter.code shouldBe "m.MyProperty += 3" + setter.name shouldBe "set_MyProperty" + setter.methodFullName shouldBe "MyData.set_MyProperty:System.Void(System.Int32)" + setter.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + case xs => fail(s"Expected single call to set_MyProperty, but got $xs") + } + } + + "have correct arguments to the set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").argument.sortBy(_.argumentIndex).l) { + case (m: Identifier) :: (plusCall: Call) :: Nil => + m.typeFullName shouldBe "MyData" + m.code shouldBe "m" + m.argumentIndex shouldBe 0 + + plusCall.argumentIndex shouldBe 1 + plusCall.code shouldBe "m.MyProperty += 3" + plusCall.methodFullName shouldBe Operators.plus + plusCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + case xs => fail(s"Unexpected arguments to set_MyProperty, got $xs") + } + } + + "have correct arguments to the synthetic `+` call" in { + inside(cpg.call.nameExact("set_MyProperty").argument(1).isCall.argument.sortBy(_.argumentIndex).l) { + case (getter: Call) :: (three: Literal) :: Nil => + getter.argumentIndex shouldBe 1 + getter.code shouldBe "m.MyProperty += 3" + getter.methodFullName shouldBe "MyData.get_MyProperty:System.Int32()" + getter.name shouldBe "get_MyProperty" + getter.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + + three.argumentIndex shouldBe 2 + three.code shouldBe "3" + three.typeFullName shouldBe "System.Int32" + case xs => fail(s"Expected two arguments for +, but got $xs") + } + } + + "have correct arguments to the synthetic getter call" in { + inside(cpg.call.nameExact("get_MyProperty").argument.sortBy(_.argumentIndex).l) { + case (receiver: Identifier) :: Nil => + receiver.argumentIndex shouldBe 0 + receiver.typeFullName shouldBe "MyData" + receiver.code shouldBe "m" + receiver.name shouldBe "m" + case xs => fail(s"Expected single argument to get_MyProperty, but got $xs") + } + } + } + + "setting a previously declared static {get;set;} property via `C.Property += y` where `C` is the class name" should { + val cpg = code(""" + |class MyData + |{ + | public static int MyProperty { get; set; } + |} + |class Main + |{ + | public static void DoStuff() + | { + | MyData.MyProperty += 3; // rendered as MyData.set_MyProperty(MyData.get_MyProperty() + 3) + | } + |} + |""".stripMargin) + + "be translated to that property's set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").l) { + case setter :: Nil => + setter.code shouldBe "MyData.MyProperty += 3" + setter.name shouldBe "set_MyProperty" + setter.methodFullName shouldBe "MyData.set_MyProperty:System.Void(System.Int32)" + setter.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + case xs => fail(s"Expected single call to set_MyProperty, but got $xs") + } + } + + "have correct arguments to the set_* method" in { + inside(cpg.call.nameExact("set_MyProperty").argument.sortBy(_.argumentIndex).l) { + case (plusCall: Call) :: Nil => + plusCall.argumentIndex shouldBe 1 + plusCall.code shouldBe "MyData.MyProperty += 3" + plusCall.methodFullName shouldBe Operators.plus + plusCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + case xs => fail(s"Unexpected arguments to set_MyProperty, got $xs") + } + } + + "have correct arguments to the synthetic `+` call" in { + inside(cpg.call.nameExact("set_MyProperty").argument(1).isCall.argument.sortBy(_.argumentIndex).l) { + case (getter: Call) :: (three: Literal) :: Nil => + getter.argumentIndex shouldBe 1 + getter.code shouldBe "MyData.MyProperty += 3" + getter.methodFullName shouldBe "MyData.get_MyProperty:System.Int32()" + getter.name shouldBe "get_MyProperty" + getter.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + + three.argumentIndex shouldBe 2 + three.code shouldBe "3" + three.typeFullName shouldBe "System.Int32" + case xs => fail(s"Expected two arguments for +, but got $xs") + } + } + + "have correct arguments to the synthetic getter call" in { + cpg.call.nameExact("get_MyProperty").argument shouldBe empty + } + } + +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/TopLevelStatementTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/TopLevelStatementTests.scala new file mode 100644 index 000000000000..0e223a191a14 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/TopLevelStatementTests.scala @@ -0,0 +1,80 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.ModifierTypes +import io.shiftleft.semanticcpg.language.* + +class TopLevelStatementTests extends CSharpCode2CpgFixture { + + "WriteLine as a top-level statement should be found inside a fictitious method" in { + val cpg = code(""" + |using System; + |Console.WriteLine("Foo"); + |""".stripMargin) + + inside(cpg.call("WriteLine").method.l) { + case method :: Nil => + method.fullName shouldBe "Test0_cs_Program.
$:System.Void(System.String[])" + method.signature shouldBe "System.Void(System.String[])" + method.typeDecl.l shouldBe cpg.typeDecl("Test0_cs_Program").l + case xs => fail(s"Expected a method above WriteLine, but found $xs") + } + } + + "fictitious class name when found inside a directory" in { + val cpg = code( + """ + |System.Console.WriteLine(args); + |""".stripMargin, + "MyProject/Main.cs" + ) + inside(cpg.call("WriteLine").method.l) { + case method :: Nil => + method.fullName shouldBe "MyProject_Main_cs_Program.
$:System.Void(System.String[])" + method.signature shouldBe "System.Void(System.String[])" + method.typeDecl.l shouldBe cpg.typeDecl("MyProject_Main_cs_Program").l + case xs => fail(s"Expected a method above WriteLine, but found $xs") + } + } + + "free-variable `args` is in fact a method parameter" in { + val cpg = code(""" + |System.Console.WriteLine(args); + |""".stripMargin) + inside(cpg.parameter("args").l) { + case args :: Nil => + args.typeFullName shouldBe "System.String[]" + args.method.fullName shouldBe "Test0_cs_Program.
$:System.Void(System.String[])" + case xs => fail(s"Expected single parameter named `args`, but found $xs") + } + } + + "class declaration defined after top-level statements is present" in { + val cpg = code(""" + |System.Console.WriteLine(args); + |class XYZ + |{ + |}""".stripMargin) + inside(cpg.typeDecl("XYZ").l) { + case xyz :: Nil => xyz.fullName shouldBe "XYZ" + case xs => fail(s"Expected single TYPE_DECL named `XYZ`, but found $xs") + } + } + + "top-level method becomes an inner static local method to the fictitious main" in { + val cpg = code(""" + |void Run() {} + |""".stripMargin) + inside(cpg.method.nameExact("Run").l) { + case run :: Nil => + run.methodReturn.typeFullName shouldBe "System.Void" + run.fullName shouldBe "Test0_cs_Program.
$.Run:System.Void()" + run.modifier.modifierType.toSet shouldBe Set(ModifierTypes.STATIC, ModifierTypes.INTERNAL) + run.parentBlock.method.l shouldBe cpg.method + .fullNameExact("Test0_cs_Program.
$:System.Void(System.String[])") + .l + case xs => + fail(s"Expected single METHOD named Run, but found $xs") + } + } +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/TypeDeclTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/TypeDeclTests.scala index a616d2207e26..7a9e4f38b706 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/TypeDeclTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/TypeDeclTests.scala @@ -288,16 +288,16 @@ class TypeDeclTests extends CSharpCode2CpgFixture { "create a TypeDecl node" in { inside(cpg.method("Main").astChildren.isTypeDecl.l) { case anonType :: Nil => - anonType.fullName shouldBe "HelloWorld.Program.Main:void(System.String[]).0" + anonType.fullName shouldBe "HelloWorld.Program.Main.0" anonType.astParentType shouldBe "METHOD" - anonType.astParentFullName shouldBe "HelloWorld.Program.Main:void(System.String[])" + anonType.astParentFullName shouldBe "HelloWorld.Program.Main:System.Void(System.String[])" case _ => fail("No TypeDecl node for anonymous object found") } } "propagate type to the LHS" in { inside(cpg.method("Main").astChildren.isBlock.astChildren.isLocal.nameExact("Foo").l) { case loc :: Nil => - loc.typeFullName shouldBe "HelloWorld.Program.Main:void(System.String[]).0" + loc.typeFullName shouldBe "HelloWorld.Program.Main.0" } } @@ -338,16 +338,16 @@ class TypeDeclTests extends CSharpCode2CpgFixture { "create a TypeDecl node" in { inside(cpg.method("Main").astChildren.isTypeDecl.l) { case anonType :: Nil => - anonType.fullName shouldBe "Foo.Bar.Main:void().0" + anonType.fullName shouldBe "Foo.Bar.Main.0" anonType.astParentType shouldBe "METHOD" - anonType.astParentFullName shouldBe "Foo.Bar.Main:void()" + anonType.astParentFullName shouldBe "Foo.Bar.Main:System.Void()" case _ => fail("No TypeDecl node for anonymous object found") } } "propagate type to the LHS" in { inside(cpg.method("Main").astChildren.isBlock.astChildren.isLocal.nameExact("Fred").l) { case loc :: Nil => - loc.typeFullName shouldBe "Foo.Bar.Main:void().0" + loc.typeFullName shouldBe "Foo.Bar.Main.0" } } @@ -380,13 +380,13 @@ class TypeDeclTests extends CSharpCode2CpgFixture { "have correct attributes" in { inside(cpg.method("Main").astChildren.isTypeDecl.l) { case anonType :: anonType2 :: Nil => - anonType.fullName shouldBe "HelloWorld.Program.Main:void(System.String[]).0" + anonType.fullName shouldBe "HelloWorld.Program.Main.0" anonType.astParentType shouldBe "METHOD" - anonType.astParentFullName shouldBe "HelloWorld.Program.Main:void(System.String[])" + anonType.astParentFullName shouldBe "HelloWorld.Program.Main:System.Void(System.String[])" - anonType2.fullName shouldBe "HelloWorld.Program.Main:void(System.String[]).1" + anonType2.fullName shouldBe "HelloWorld.Program.Main.1" anonType2.astParentType shouldBe "METHOD" - anonType2.astParentFullName shouldBe "HelloWorld.Program.Main:void(System.String[])" + anonType2.astParentFullName shouldBe "HelloWorld.Program.Main:System.Void(System.String[])" case _ => fail("There should be exactly 2 anonymous types present") } } @@ -394,8 +394,8 @@ class TypeDeclTests extends CSharpCode2CpgFixture { "propagate type to the LHS" in { inside(cpg.method("Main").astChildren.isBlock.astChildren.isLocal.l) { case loc :: loc2 :: Nil => - loc.typeFullName shouldBe "HelloWorld.Program.Main:void(System.String[]).0" - loc2.typeFullName shouldBe "HelloWorld.Program.Main:void(System.String[]).1" + loc.typeFullName shouldBe "HelloWorld.Program.Main.0" + loc2.typeFullName shouldBe "HelloWorld.Program.Main.1" case _ => fail("Exactly two locals should be present") } } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/UsingDirectiveTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/UsingDirectiveTests.scala new file mode 100644 index 000000000000..73a55bec1a98 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/UsingDirectiveTests.scala @@ -0,0 +1,59 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.semanticcpg.language.* + +class UsingDirectiveTests extends CSharpCode2CpgFixture { + + "`global using` directive in another file" should { + val cpg = code(""" + |class Foo + |{ + | static void Run() + | { + | Console.WriteLine("Hello"); + | } + |}""".stripMargin) + .moreCode( + """ + |global using System; + |""".stripMargin, + "globals.cs" + ) + + "make the imported namespace available in the current file" in { + inside(cpg.call("WriteLine").l) { + case writeLine :: Nil => + writeLine.methodFullName shouldBe "System.Console.WriteLine:System.Void(System.String)" + case xs => + fail(s"Expected single WriteLine call, but found $xs") + } + } + } + + "`using` directive in another file" should { + val cpg = code(""" + |class Foo + |{ + | static void Run() + | { + | Console.WriteLine("Hello"); + | } + |}""".stripMargin) + .moreCode( + """ + |using System; + |""".stripMargin, + "dummy.cs" + ) + + "not affect the imported namespaces in the current file" in { + inside(cpg.call("WriteLine").l) { + case writeLine :: Nil => + writeLine.methodFullName shouldBe ".WriteLine:" + case xs => + fail(s"Expected single WriteLine call, but found $xs") + } + } + } +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/dataflow/ControlStructureDataflowTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/dataflow/ControlStructureDataflowTests.scala index 9f086f405db0..8d2c5d2e2a66 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/dataflow/ControlStructureDataflowTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/dataflow/ControlStructureDataflowTests.scala @@ -38,7 +38,7 @@ class ControlStructureDataflowTests extends CSharpCode2CpgFixture(withDataFlow = "find a path from element to Write and from i to assignment through a foreach loop" in { val elementSrc = cpg.identifier.nameExact("element").l val writeSink = cpg.call.nameExact("Write").l - writeSink.reachableBy(elementSrc).size shouldBe 1 + writeSink.reachableBy(elementSrc).size shouldBe 2 val assignmentSrc = cpg.identifier.nameExact("i").lineNumber(10).l val newI = cpg.identifier.nameExact("newI").l diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/dataflow/OperatorDataflowTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/dataflow/OperatorDataflowTests.scala index 15771fc038b8..674120c94020 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/dataflow/OperatorDataflowTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/dataflow/OperatorDataflowTests.scala @@ -17,6 +17,7 @@ class OperatorDataflowTests extends CSharpCode2CpgFixture(withDataFlow = true) { val src = cpg.identifier.nameExact("a").l; val sink = cpg.identifier.nameExact("d").l sink.reachableBy(src).size shouldBe 2 + sink.reachableBy(cpg.literal("3")).size shouldBe 1 } "be reachable (case 2)" in { diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/testfixtures/CSharpCode2CpgFixture.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/testfixtures/CSharpCode2CpgFixture.scala index 9b31f150451f..a410d8a0ff06 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/testfixtures/CSharpCode2CpgFixture.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/testfixtures/CSharpCode2CpgFixture.scala @@ -1,8 +1,9 @@ package io.joern.csharpsrc2cpg.testfixtures import io.joern.csharpsrc2cpg.{CSharpSrc2Cpg, Config} +import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.language.Path -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.joern.dataflowengineoss.testfixtures.{SemanticCpgTestFixture, SemanticTestCpg} import io.joern.x2cpg.testfixtures.{Code2CpgFixture, DefaultTestCpg, LanguageFrontend} import io.joern.x2cpg.{ValidationMode, X2Cpg} @@ -16,14 +17,14 @@ import java.io.File class CSharpCode2CpgFixture( withPostProcessing: Boolean = false, withDataFlow: Boolean = false, - extraFlows: List[FlowSemantic] = List.empty + semantics: Semantics = DefaultSemantics() ) extends Code2CpgFixture(() => new DefaultTestCpgWithCSharp() .withOssDataflow(withDataFlow) - .withExtraFlows(extraFlows) + .withSemantics(semantics) .withPostProcessingPasses(withPostProcessing) ) - with SemanticCpgTestFixture(extraFlows) + with SemanticCpgTestFixture(semantics) with Inside { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/Ghidra2Cpg.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/Ghidra2Cpg.scala index 751ec5c511fe..02c66985fa8a 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/Ghidra2Cpg.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/Ghidra2Cpg.scala @@ -13,7 +13,7 @@ import ghidra.program.model.listing.Program import ghidra.program.util.{DefinedDataIterator, GhidraProgramUtilities} import ghidra.util.exception.InvalidInputException import ghidra.util.task.TaskMonitor -import io.joern.ghidra2cpg.passes._ +import io.joern.ghidra2cpg.passes.* import io.joern.ghidra2cpg.passes.arm.ArmFunctionPass import io.joern.ghidra2cpg.passes.mips.{LoHiPass, MipsFunctionPass} import io.joern.ghidra2cpg.passes.x86.{ReturnEdgesPass, X86FunctionPass} @@ -26,7 +26,7 @@ import utilities.util.FileUtilities import java.io.File import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.util.Try class Ghidra2Cpg extends X2CpgFrontend[Config] { @@ -149,9 +149,6 @@ class Ghidra2Cpg extends X2CpgFrontend[Config] { new LiteralPass(cpg, flatProgramAPI).createAndApply() } - private class HeadlessProjectConnection(projectManager: HeadlessGhidraProjectManager, connection: GhidraURLConnection) - extends DefaultProject(projectManager, connection) {} - private class HeadlessGhidraProjectManager extends DefaultProjectManager {} } diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/Main.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/Main.scala index 3c6ee6139f2f..189f3f8fd034 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/Main.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/Main.scala @@ -1,6 +1,6 @@ package io.joern.ghidra2cpg -import io.joern.ghidra2cpg.Frontend._ +import io.joern.ghidra2cpg.Frontend.* import io.joern.x2cpg.{X2CpgConfig, X2CpgMain} import scopt.OParser diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/FunctionPass.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/FunctionPass.scala index 4af29997e1d5..c32880c96990 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/FunctionPass.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/FunctionPass.scala @@ -6,17 +6,17 @@ import ghidra.program.model.lang.Register import ghidra.program.model.listing.{CodeUnitFormat, CodeUnitFormatOptions, Function, Instruction, Program} import ghidra.program.model.pcode.{HighFunction, HighSymbol} import ghidra.program.model.scalar.Scalar -import io.joern.ghidra2cpg._ -import io.joern.ghidra2cpg.processors._ +import io.joern.ghidra2cpg.* +import io.joern.ghidra2cpg.processors.* import io.joern.ghidra2cpg.utils.Decompiler -import io.joern.ghidra2cpg.utils.Utils._ +import io.joern.ghidra2cpg.utils.Utils.* import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{CfgNodeNew, NewBlock, NewMethod} import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} import io.shiftleft.passes.ForkJoinParallelCpgPass import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.language.implicitConversions abstract class FunctionPass( @@ -60,7 +60,7 @@ abstract class FunctionPass( override def generateParts(): Array[Function] = functions.toArray - implicit def intToIntegerOption(intOption: Option[Int]): Option[Integer] = intOption.map(intValue => { + implicit def intToIntegerOption(intOption: Option[Int]): Option[Int] = intOption.map(intValue => { val integerValue = intValue integerValue }) diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/JumpPass.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/JumpPass.scala index dfbd5fa2beca..7a58ae23608e 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/JumpPass.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/JumpPass.scala @@ -4,7 +4,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.{Call, Method} import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import scala.util.Try diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/LiteralPass.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/LiteralPass.scala index 97e27d06c0b9..97f37248b57a 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/LiteralPass.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/LiteralPass.scala @@ -6,7 +6,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes import io.shiftleft.passes.ForkJoinParallelCpgPass -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.language.implicitConversions class LiteralPass(cpg: Cpg, flatProgramAPI: FlatProgramAPI) extends ForkJoinParallelCpgPass[String](cpg) { diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/PCodePass.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/PCodePass.scala index 7905e95c11df..de8fca302cdf 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/PCodePass.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/PCodePass.scala @@ -2,15 +2,14 @@ package io.joern.ghidra2cpg.passes import ghidra.program.model.listing.{Function, Program} import ghidra.program.util.DefinedDataIterator -import io.joern.ghidra2cpg._ -import io.joern.ghidra2cpg.utils.Utils._ +import io.joern.ghidra2cpg.* +import io.joern.ghidra2cpg.utils.Utils.* import io.joern.ghidra2cpg.utils.{Decompiler, PCodeMapper} -import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewMethod} -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes, nodes} import io.shiftleft.passes.ForkJoinParallelCpgPass -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.language.implicitConversions class PCodePass(currentProgram: Program, fileName: String, functions: List[Function], cpg: Cpg, decompiler: Decompiler) diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/arm/ArmFunctionPass.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/arm/ArmFunctionPass.scala index c0ebcf7d26df..60963b18126e 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/arm/ArmFunctionPass.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/arm/ArmFunctionPass.scala @@ -5,9 +5,8 @@ import io.joern.ghidra2cpg.utils.Decompiler import io.joern.ghidra2cpg.passes.FunctionPass import io.joern.ghidra2cpg.processors.ArmProcessor import io.joern.ghidra2cpg.utils.Utils.{checkIfExternal, createMethodNode, createReturnNode} -import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.NewBlock -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes, nodes} class ArmFunctionPass( currentProgram: Program, diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/LoHiPass.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/LoHiPass.scala index 2b592a15e403..70f7bbe485dc 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/LoHiPass.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/LoHiPass.scala @@ -1,10 +1,10 @@ package io.joern.ghidra2cpg.passes.mips import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{EdgeTypes, PropertyNames} import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LoHiPass(cpg: Cpg) extends ForkJoinParallelCpgPass[(Call, Call)](cpg) { override def generateParts(): Array[(Call, Call)] = { @@ -23,6 +23,9 @@ class LoHiPass(cpg: Cpg) extends ForkJoinParallelCpgPass[(Call, Call)](cpg) { }.toArray override def runOnPart(diffGraph: DiffGraphBuilder, pair: (Call, Call)): Unit = { - diffGraph.addEdge(pair._1, pair._2, EdgeTypes.REACHING_DEF, PropertyNames.VARIABLE, pair._1.code) + // in flatgraph an edge may have zero or one properties and they're not named... + // in this case we know that we're dealing with ReachingDef edges which has the `variable` property + val variableProperty = pair._1.code + diffGraph.addEdge(pair._1, pair._2, EdgeTypes.REACHING_DEF, variableProperty) } } diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/MipsFunctionPass.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/MipsFunctionPass.scala index 6e5a238b570d..71222e896100 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/MipsFunctionPass.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/MipsFunctionPass.scala @@ -2,12 +2,12 @@ package io.joern.ghidra2cpg.passes.mips import ghidra.program.model.address.GenericAddress import ghidra.program.model.lang.Register import ghidra.program.model.listing.{Function, Instruction, Program} -import ghidra.program.model.pcode.PcodeOp._ +import ghidra.program.model.pcode.PcodeOp.* import ghidra.program.model.pcode.{HighFunction, PcodeOp, PcodeOpAST, Varnode} import ghidra.program.model.scalar.Scalar import io.joern.ghidra2cpg.passes.FunctionPass import io.joern.ghidra2cpg.processors.MipsProcessor -import io.joern.ghidra2cpg.utils.Utils._ +import io.joern.ghidra2cpg.utils.Utils.* import io.joern.ghidra2cpg.Types import io.joern.ghidra2cpg.utils.Decompiler import io.shiftleft.codepropertygraph.generated.Cpg @@ -15,7 +15,7 @@ import io.shiftleft.codepropertygraph.generated.nodes.{CfgNodeNew, NewBlock} import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} import org.slf4j.LoggerFactory -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.language.implicitConversions class MipsFunctionPass( diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/MipsReturnEdgesPass.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/MipsReturnEdgesPass.scala index 03f63d30fb6a..468d6a3e82f5 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/MipsReturnEdgesPass.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/mips/MipsReturnEdgesPass.scala @@ -3,7 +3,7 @@ package io.joern.ghidra2cpg.passes.mips import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.{EdgeTypes, PropertyNames} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.slf4j.{Logger, LoggerFactory} class MipsReturnEdgesPass(cpg: Cpg) extends CpgPass(cpg) { @@ -17,7 +17,10 @@ class MipsReturnEdgesPass(cpg: Cpg) extends CpgPass(cpg) { // the first .cfgNext is skipping a _nop instruction after the call val to = from.cfgNext.cfgNext.isCall.argument.code("v(0|1)").headOption if (to.nonEmpty) { - diffGraph.addEdge(from, to.get, EdgeTypes.REACHING_DEF, PropertyNames.VARIABLE, from.code) + // in flatgraph an edge may have zero or one properties and they're not named... + // in this case we know that we're dealing with ReachingDef edges which has the `variable` property + val variableProperty = from.code + diffGraph.addEdge(from, to.get, EdgeTypes.REACHING_DEF, variableProperty) } } } diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/x86/ReturnEdgesPass.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/x86/ReturnEdgesPass.scala index 956c55cb58e9..0669e45b7a16 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/x86/ReturnEdgesPass.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/x86/ReturnEdgesPass.scala @@ -3,7 +3,7 @@ package io.joern.ghidra2cpg.passes.x86 import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.{EdgeTypes, PropertyNames} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.slf4j.LoggerFactory class ReturnEdgesPass(cpg: Cpg) extends CpgPass(cpg) { @@ -15,7 +15,11 @@ class ReturnEdgesPass(cpg: Cpg) extends CpgPass(cpg) { cpg.call.nameNot(".*").foreach { from => // We expect RAX/EAX as return val to = from.cfgNext.isCall.argument.code("(R|E)AX").headOption - if (to.nonEmpty) diffGraph.addEdge(from, to.get, EdgeTypes.REACHING_DEF, PropertyNames.VARIABLE, from.code) + + // in flatgraph an edge may have zero or one properties and they're not named... + // in this case we know that we're dealing with ReachingDef edges which has the `variable` property + val variableProperty = from.code + if (to.nonEmpty) diffGraph.addEdge(from, to.get, EdgeTypes.REACHING_DEF, variableProperty) } } diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/x86/X86FunctionPass.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/x86/X86FunctionPass.scala index c7e5ae1c0134..aff0452836d8 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/x86/X86FunctionPass.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/passes/x86/X86FunctionPass.scala @@ -4,7 +4,7 @@ import ghidra.program.model.listing.{Function, Program} import io.joern.ghidra2cpg.utils.Decompiler import io.joern.ghidra2cpg.passes.FunctionPass import io.joern.ghidra2cpg.processors.X86Processor -import io.joern.ghidra2cpg.utils.Utils._ +import io.joern.ghidra2cpg.utils.Utils.* import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewMethod} diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/utils/PCodeMapper.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/utils/PCodeMapper.scala index ccd8fa695d6a..d24f15f3b5db 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/utils/PCodeMapper.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/utils/PCodeMapper.scala @@ -2,17 +2,17 @@ package io.joern.ghidra2cpg.utils import ghidra.app.util.template.TemplateSimplifier import ghidra.program.model.listing.{CodeUnitFormat, CodeUnitFormatOptions, Function, Instruction} -import ghidra.program.model.pcode.PcodeOp._ +import ghidra.program.model.pcode.PcodeOp.* import ghidra.program.model.pcode.{HighFunction, PcodeOp, PcodeOpAST, Varnode} import io.joern.ghidra2cpg.Types //import io.joern.ghidra2cpg.utils.Utils.{createCallNode, createIdentifier, createLiteral} -import io.joern.ghidra2cpg.utils.Utils._ +import io.joern.ghidra2cpg.utils.Utils.* import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.CfgNodeNew import org.slf4j.LoggerFactory -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.language.implicitConversions class State(argumentIndex: Int) { var argument: Int = argumentIndex diff --git a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/utils/Utils.scala b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/utils/Utils.scala index 5185c815db89..1d9128ba886c 100644 --- a/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/utils/Utils.scala +++ b/joern-cli/frontends/ghidra2cpg/src/main/scala/io/joern/ghidra2cpg/utils/Utils.scala @@ -2,11 +2,11 @@ package io.joern.ghidra2cpg.utils import ghidra.program.model.listing.{Function, Instruction, Program} import io.joern.ghidra2cpg.Types -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.NodeTypes import io.shiftleft.proto.cpg.Cpg.DispatchTypes -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.language.implicitConversions object Utils { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/config/ConfigTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/config/ConfigTests.scala index 6e2fa7e922e9..87fd3c6e27c3 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/config/ConfigTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/config/ConfigTests.scala @@ -18,7 +18,9 @@ class ConfigTests extends AnyWordSpec with Matchers with Inside { "--output", "OUTPUT", "--exclude", - "1EXCLUDE_FILE,2EXCLUDE_FILE", + "1EXCLUDE_FILE", + "--exclude", + "2EXCLUDE_FILE", "--exclude-regex", "EXCLUDE_REGEX" // Frontend-specific args diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/fixtures/DataFlowBinToCpgSuite.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/fixtures/DataFlowBinToCpgSuite.scala index f8b6abfbdf82..a586f3206acd 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/fixtures/DataFlowBinToCpgSuite.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/fixtures/DataFlowBinToCpgSuite.scala @@ -9,12 +9,13 @@ import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.dotextension.ImageViewer import io.shiftleft.semanticcpg.layers.* +import scala.compiletime.uninitialized import scala.sys.process.Process import scala.util.Try class DataFlowBinToCpgSuite extends GhidraBinToCpgSuite { - implicit var context: EngineContext = scala.compiletime.uninitialized + implicit var context: EngineContext = uninitialized override def beforeAll(): Unit = { super.beforeAll() @@ -33,7 +34,7 @@ class DataFlowBinToCpgSuite extends GhidraBinToCpgSuite { new OssDataFlow(options).run(context) } - protected implicit def int2IntegerOption(x: Int): Option[Integer] = + protected implicit def int2IntegerOption(x: Int): Option[Int] = Some(x) protected def getMemberOfType(cpg: Cpg, typeName: String, memberName: String): Iterator[Member] = diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/fixtures/GhidraBinToCpgSuite.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/fixtures/GhidraBinToCpgSuite.scala index 07546bea0205..21d619dcb701 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/fixtures/GhidraBinToCpgSuite.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/fixtures/GhidraBinToCpgSuite.scala @@ -6,8 +6,8 @@ import io.joern.x2cpg.testfixtures.LanguageFrontend import io.shiftleft.utils.ProjectRoot import org.apache.commons.io.FileUtils import io.shiftleft.codepropertygraph.generated.nodes -import io.joern.dataflowengineoss.language._ -import io.shiftleft.semanticcpg.language._ +import io.joern.dataflowengineoss.language.* +import io.shiftleft.semanticcpg.language.* import java.nio.file.Files diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/CallArgumentsTest.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/CallArgumentsTest.scala index f885b944657c..9fc3cf0d22ec 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/CallArgumentsTest.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/CallArgumentsTest.scala @@ -1,7 +1,7 @@ package io.joern.ghidra2cpg.querying.mips import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CallArgumentsTest extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/DataFlowTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/DataFlowTests.scala index 99f43aace6f6..dc6babe19dae 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/DataFlowTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/DataFlowTests.scala @@ -1,13 +1,13 @@ package io.joern.ghidra2cpg.querying.mips -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlowOptions} import io.joern.dataflowengineoss.queryengine.EngineContext import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite import io.joern.x2cpg.X2Cpg.applyDefaultOverlays import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.semanticcpg.language.{ICallResolver, _} -import io.shiftleft.semanticcpg.layers._ +import io.shiftleft.semanticcpg.layers.* class DataFlowTests extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/DataFlowThroughLoHiRegistersTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/DataFlowThroughLoHiRegistersTests.scala index e0c1b017f17f..f42b3b965165 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/DataFlowThroughLoHiRegistersTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/mips/DataFlowThroughLoHiRegistersTests.scala @@ -1,14 +1,14 @@ package io.joern.ghidra2cpg.querying.mips -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlowOptions} import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.{Parser, Semantics} +import io.joern.dataflowengineoss.semanticsloader.{FullNameSemanticsParser, FullNameSemantics, Semantics} import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite import io.joern.x2cpg.X2Cpg.applyDefaultOverlays import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ -import io.shiftleft.semanticcpg.layers._ +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.layers.* class DataFlowThroughLoHiRegistersTests extends GhidraBinToCpgSuite { override def passes(cpg: Cpg): Unit = { @@ -37,7 +37,7 @@ class DataFlowThroughLoHiRegistersTests extends GhidraBinToCpgSuite { |".incBy" 1->1 2->1 3->1 4->1 |".rotateRight" 2->1 |""".stripMargin - implicit val semantics: Semantics = Semantics.fromList(new Parser().parse(customSemantics)) + implicit val semantics: Semantics = FullNameSemantics.fromList(new FullNameSemanticsParser().parse(customSemantics)) implicit val context: EngineContext = EngineContext(semantics) "should find flows through `div*` instructions" in { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/CFGTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/CFGTests.scala index 5b9e420a5059..15f43cc1807b 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/CFGTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/CFGTests.scala @@ -2,7 +2,7 @@ package io.joern.ghidra2cpg.querying.x86 import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite import io.shiftleft.codepropertygraph.generated.nodes.Call -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CFGTests extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/DataFlowTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/DataFlowTests.scala index c0e439760c4f..de5a4ea2f40b 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/DataFlowTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/DataFlowTests.scala @@ -2,14 +2,14 @@ package io.joern.ghidra2cpg.querying.x86 import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite import io.shiftleft.codepropertygraph.generated.Cpg -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlowOptions} import io.joern.dataflowengineoss.queryengine.EngineContext import io.joern.dataflowengineoss.semanticsloader.Semantics import io.joern.dataflowengineoss.DefaultSemantics import io.joern.x2cpg.layers.{Base, CallGraph, ControlFlow, TypeRelations} -import io.shiftleft.semanticcpg.language._ -import io.shiftleft.semanticcpg.layers._ +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.layers.* class DataFlowTests extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/FileTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/FileTests.scala index 2d77c4520a6c..645b906d5873 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/FileTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/FileTests.scala @@ -1,7 +1,7 @@ package io.joern.ghidra2cpg.querying.x86 import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import java.io.File diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/LiteralNodeTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/LiteralNodeTests.scala index 2b10f2019861..f7bab3771104 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/LiteralNodeTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/LiteralNodeTests.scala @@ -1,7 +1,7 @@ package io.joern.ghidra2cpg.querying.x86 import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LiteralNodeTests extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/LocalNodeTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/LocalNodeTests.scala index 009cf922c0fa..1a54bd8a3bc4 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/LocalNodeTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/LocalNodeTests.scala @@ -1,7 +1,7 @@ package io.joern.ghidra2cpg.querying.x86 import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LocalNodeTests extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/MetaDataNodeTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/MetaDataNodeTests.scala index f8b5143c4252..9f141e9ff412 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/MetaDataNodeTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/MetaDataNodeTests.scala @@ -2,7 +2,7 @@ package io.joern.ghidra2cpg.querying.x86 import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite import io.shiftleft.codepropertygraph.generated.Languages -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MetaDataNodeTests extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/MethodNodeTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/MethodNodeTests.scala index c4c04ce1c827..85f99d1afa53 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/MethodNodeTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/MethodNodeTests.scala @@ -1,7 +1,7 @@ package io.joern.ghidra2cpg.querying.x86 import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MethodNodeTests extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/NamespaceBlockTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/NamespaceBlockTests.scala index f9341692c73a..abe9dcd5f9db 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/NamespaceBlockTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/NamespaceBlockTests.scala @@ -1,7 +1,7 @@ package io.joern.ghidra2cpg.querying.x86 import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.{FileTraversal, NamespaceTraversal} class NamespaceBlockTests extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/ParameterNodeTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/ParameterNodeTests.scala index 9f2c03f4b1fb..c37d5ac9c4d5 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/ParameterNodeTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/ParameterNodeTests.scala @@ -1,7 +1,7 @@ package io.joern.ghidra2cpg.querying.x86 import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ParameterNodeTests extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/RefNodeTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/RefNodeTests.scala index b3b317abc4f1..2db2de746f84 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/RefNodeTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/RefNodeTests.scala @@ -1,7 +1,7 @@ package io.joern.ghidra2cpg.querying.x86 import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class RefNodeTests extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/ReturnNodeTests.scala b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/ReturnNodeTests.scala index 795b879e4f38..dd564ce01662 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/ReturnNodeTests.scala +++ b/joern-cli/frontends/ghidra2cpg/src/test/scala/io/joern/ghidra2cpg/querying/x86/ReturnNodeTests.scala @@ -1,7 +1,7 @@ package io.joern.ghidra2cpg.querying.x86 import io.joern.ghidra2cpg.fixtures.GhidraBinToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ReturnNodeTests extends GhidraBinToCpgSuite { diff --git a/joern-cli/frontends/ghidra2cpg/src/test/testbinaries/coverage/testscript.sc b/joern-cli/frontends/ghidra2cpg/src/test/testbinaries/coverage/testscript.sc index 81cab22add9e..2cd441ae3888 100644 --- a/joern-cli/frontends/ghidra2cpg/src/test/testbinaries/coverage/testscript.sc +++ b/joern-cli/frontends/ghidra2cpg/src/test/testbinaries/coverage/testscript.sc @@ -2,7 +2,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.joern.dataflowengineoss.language._ import io.shiftleft.semanticcpg.language._ import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment -import overflowdb.traversal._ +import flatgraph.traversal._ @main def main(testBinary: String) = { importCode.ghidra(testBinary) diff --git a/joern-cli/frontends/gosrc2cpg/build.sbt b/joern-cli/frontends/gosrc2cpg/build.sbt index b91c175eaa13..90079f60868c 100644 --- a/joern-cli/frontends/gosrc2cpg/build.sbt +++ b/joern-cli/frontends/gosrc2cpg/build.sbt @@ -37,7 +37,7 @@ lazy val GoAstgenMac = "goastgen-macos" lazy val GoAstgenMacArm = "goastgen-macos-arm64" lazy val goAstGenDlUrl = settingKey[String]("goastgen download url") -goAstGenDlUrl := s"https://github.com/Privado-Inc/goastgen/releases/download/v${goAstGenVersion.value}/" +goAstGenDlUrl := s"https://github.com/joernio/goastgen/releases/download/v${goAstGenVersion.value}/" def hasCompatibleAstGenVersion(goAstGenVersion: String): Boolean = { Try("goastgen -version".!!).toOption.map(_.strip()) match { @@ -78,16 +78,15 @@ goAstGenDlTask := { val goAstGenDir = baseDirectory.value / "bin" / "astgen" goAstGenBinaryNames.value.foreach { fileName => - DownloadHelper.ensureIsAvailable(s"${goAstGenDlUrl.value}$fileName", goAstGenDir / fileName) + val file = goAstGenDir / fileName + DownloadHelper.ensureIsAvailable(s"${goAstGenDlUrl.value}$fileName", file) + // permissions are lost during the download; need to set them manually + file.setExecutable(true, false) } val distDir = (Universal / stagingDirectory).value / "bin" / "astgen" distDir.mkdirs() - IO.copyDirectory(goAstGenDir, distDir) - - // permissions are lost during the download; need to set them manually - goAstGenDir.listFiles().foreach(_.setExecutable(true, false)) - distDir.listFiles().foreach(_.setExecutable(true, false)) + IO.copyDirectory(goAstGenDir, distDir, preserveExecutable = true) } Compile / compile := ((Compile / compile) dependsOn goAstGenDlTask).value @@ -99,3 +98,7 @@ stage := Def .sequential(goAstGenSetAllPlatforms, Universal / stage) .andFinally(System.setProperty("ALL_PLATFORMS", "FALSE")) .value + +/** write the astgen version to the manifest for downstream usage */ +Compile / packageBin / packageOptions += + Package.ManifestAttributes(new java.util.jar.Attributes.Name("Go-AstGen-Version") -> goAstGenVersion.value) diff --git a/joern-cli/frontends/gosrc2cpg/src/main/resources/application.conf b/joern-cli/frontends/gosrc2cpg/src/main/resources/application.conf index 2d296c7eb816..cd729e0c5ef1 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/resources/application.conf +++ b/joern-cli/frontends/gosrc2cpg/src/main/resources/application.conf @@ -1,3 +1,3 @@ gosrc2cpg { - goastgen_version: "0.17.0" + goastgen_version: "0.1.0" } diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/GoSrc2Cpg.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/GoSrc2Cpg.scala index efbcaa30af5e..9fa86ed511d4 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/GoSrc2Cpg.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/GoSrc2Cpg.scala @@ -23,30 +23,33 @@ class GoSrc2Cpg(goGlobalOption: Option[GoGlobal] = Option(GoGlobal())) extends X def createCpg(config: Config): Try[Cpg] = { withNewEmptyCpg(config.outputPath, config) { (cpg, config) => File.usingTemporaryDirectory("gosrc2cpgOut") { tmpDir => - goGlobalOption - .orElse(Option(GoGlobal())) - .foreach(goGlobal => { - MetaDataPass(cpg, Languages.GOLANG, config.inputPath).createAndApply() - val astGenResult = new AstGenRunner(config).execute(tmpDir).asInstanceOf[GoAstGenRunnerResult] - goMod = Some( - GoModHelper( - Some(config), - astGenResult.parsedModFile - .flatMap(modFile => GoAstJsonParser.readModFile(Paths.get(modFile)).map(x => x)) + MetaDataPass(cpg, Languages.GOLANG, config.inputPath).createAndApply() + val astGenResults = new AstGenRunner(config).executeForGo(tmpDir) + astGenResults.foreach(astGenResult => { + goGlobalOption + .orElse(Option(GoGlobal())) + .foreach(goGlobal => { + goMod = Some( + GoModHelper( + Some(astGenResult.modulePath), + astGenResult.parsedModFile + .flatMap(modFile => GoAstJsonParser.readModFile(Paths.get(modFile)).map(x => x)) + ) ) - ) - goGlobal.mainModule = goMod.flatMap(modHelper => modHelper.getModMetaData().map(mod => mod.module.name)) - InitialMainSrcPass(cpg, astGenResult.parsedFiles, config, goMod.get, goGlobal, tmpDir).createAndApply() - if goGlobal.pkgLevelVarAndConstantAstMap.size() > 0 then - PackageCtorCreationPass(cpg, config, goGlobal).createAndApply() - if (config.fetchDependencies) { - goGlobal.processingDependencies = true - DownloadDependenciesPass(cpg, goMod.get, goGlobal, config).process() - goGlobal.processingDependencies = false - } - AstCreationPass(cpg, astGenResult.parsedFiles, config, goMod.get, goGlobal, tmpDir, report).createAndApply() - report.print() - }) + goGlobal.mainModule = goMod.flatMap(modHelper => modHelper.getModMetaData().map(mod => mod.module.name)) + InitialMainSrcPass(cpg, astGenResult.parsedFiles, config, goMod.get, goGlobal, tmpDir).createAndApply() + if goGlobal.pkgLevelVarAndConstantAstMap.size() > 0 then + PackageCtorCreationPass(cpg, config, goGlobal).createAndApply() + if (config.fetchDependencies) { + goGlobal.processingDependencies = true + DownloadDependenciesPass(cpg, goMod.get, goGlobal, config).process() + goGlobal.processingDependencies = false + } + AstCreationPass(cpg, astGenResult.parsedFiles, config, goMod.get, goGlobal, tmpDir, report) + .createAndApply() + report.print() + }) + }) } } } diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/Main.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/Main.scala index 5112f2ccf523..4b3e9a15b0b6 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/Main.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/Main.scala @@ -3,6 +3,7 @@ package io.joern.gosrc2cpg import io.joern.gosrc2cpg.Frontend.* import io.joern.x2cpg.astgen.AstGenConfig import io.joern.x2cpg.{X2CpgConfig, X2CpgMain} +import io.joern.x2cpg.utils.server.FrontendHTTPServer import scopt.OParser import java.nio.file.Paths @@ -42,10 +43,15 @@ object Frontend { } -object Main extends X2CpgMain(cmdLineParser, new GoSrc2Cpg()) { +object Main extends X2CpgMain(cmdLineParser, new GoSrc2Cpg()) with FrontendHTTPServer[Config, GoSrc2Cpg] { + + override protected def newDefaultConfig(): Config = Config() def run(config: Config, gosrc2cpg: GoSrc2Cpg): Unit = { - val absPath = Paths.get(config.inputPath).toAbsolutePath.toString - gosrc2cpg.run(config.withInputPath(absPath)) + if (config.serverMode) { startup() } + else { + val absPath = Paths.get(config.inputPath).toAbsolutePath.toString + gosrc2cpg.run(config.withInputPath(absPath)) + } } } diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstCreator.scala index 3fc1aa1d21c8..1d0a90cbdd2a 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstCreator.scala @@ -12,7 +12,7 @@ import io.joern.x2cpg.{Ast, AstCreatorBase, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.NewNode import io.shiftleft.codepropertygraph.generated.{ModifierTypes, NodeTypes} import org.slf4j.{Logger, LoggerFactory} -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import ujson.Value import java.nio.file.Paths diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForPackageConstructorCreator.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForPackageConstructorCreator.scala index d7c2f7b6a4e5..31b8d03a1732 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForPackageConstructorCreator.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForPackageConstructorCreator.scala @@ -7,7 +7,7 @@ import io.joern.x2cpg.astgen.AstGenNodeBuilder import io.joern.x2cpg.{Ast, AstCreatorBase, ValidationMode, Defines as XDefines} import io.shiftleft.codepropertygraph.generated.NodeTypes import org.apache.commons.lang3.StringUtils -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import ujson.Value import scala.collection.immutable.Set diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForStatementsCreator.scala index 926c1e0d0f39..f53ba950eb71 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForStatementsCreator.scala @@ -38,7 +38,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t case DeclStmt => astForNode(statement.json(ParserKeys.Decl)) case ExprStmt => astsForExpression(createParserNodeInfo(statement.json(ParserKeys.X))) case ForStmt => Seq(astForForStatement(statement)) - case IfStmt => Seq(astForIfStatement(statement)) + case IfStmt => astForIfStatement(statement) case IncDecStmt => Seq(astForIncDecStatement(statement)) case RangeStmt => Seq(astForRangeStatement(statement)) case SwitchStmt => Seq(astForSwitchStatement(statement)) @@ -133,7 +133,13 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t ast } - private def astForIfStatement(ifStmt: ParserNodeInfo): Ast = { + private def astForIfStatement(ifStmt: ParserNodeInfo): Seq[Ast] = { + // handle init code before condition in if; + val initParserNode = nullSafeCreateParserNodeInfo(ifStmt.json.obj.get(ParserKeys.Init)) + val initAstBlock = blockNode(ifStmt, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(initAstBlock) + val initAst = blockAst(initAstBlock, astsForStatement(initParserNode, 1).toList) + scope.popScope() val conditionParserNode = createParserNodeInfo(ifStmt.json(ParserKeys.Cond)) val conditionAst = astForConditionExpression(conditionParserNode) @@ -159,7 +165,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t Ast(elseNode).withChild(blockAst(elseBlock, a.toList)) case _ => Ast() } - controlStructureAst(ifNode, Some(conditionAst), Seq(thenAst, elseAst)) + Seq(initAst, controlStructureAst(ifNode, Some(conditionAst), Seq(thenAst, elseAst))) } private def astForSwitchStatement(switchStmt: ParserNodeInfo): Ast = { diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/InitialMainSrcProcessor.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/InitialMainSrcProcessor.scala index 5893116d563a..d56020f92998 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/InitialMainSrcProcessor.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/InitialMainSrcProcessor.scala @@ -5,7 +5,7 @@ import io.joern.gosrc2cpg.parser.{ParserKeys, ParserNodeInfo} import io.joern.gosrc2cpg.utils.UtilityConstants.fileSeparateorPattern import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.NewNamespaceBlock -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import ujson.{Arr, Obj, Value} import java.io.File diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/model/GoMod.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/model/GoMod.scala index 619aa1834231..df2d810c3f6f 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/model/GoMod.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/model/GoMod.scala @@ -1,6 +1,5 @@ package io.joern.gosrc2cpg.model -import io.joern.gosrc2cpg.Config import io.joern.gosrc2cpg.utils.UtilityConstants.fileSeparateorPattern import upickle.default.* @@ -9,11 +8,10 @@ import java.util.Set import java.util.concurrent.ConcurrentSkipListSet import scala.util.control.Breaks.* -class GoModHelper(config: Option[Config] = None, meta: Option[GoMod] = None) { +class GoModHelper(modulePath: Option[String] = None, meta: Option[GoMod] = None) { def getModMetaData(): Option[GoMod] = meta def getNameSpace(compilationUnitFilePath: String, pkg: String): String = { - if (meta.isEmpty || compilationUnitFilePath == null || compilationUnitFilePath.isEmpty) { // When there no go.mod file, we don't have the information about the module prefix // In this case we will use package name as a namespace @@ -29,7 +27,7 @@ class GoModHelper(config: Option[Config] = None, meta: Option[GoMod] = None) { // 1. if there is go file inside /first/second/test.go (package main) => '/first/second/main' // 2. /test.go (package main) => 'main' - val remainingpath = compilationUnitFilePath.stripPrefix(config.get.inputPath) + val remainingpath = compilationUnitFilePath.stripPrefix(modulePath.get) val pathTokens = remainingpath.split(fileSeparateorPattern) val tokens = pathTokens.dropRight(1).filterNot(x => x == null || x.trim.isEmpty) :+ pkg return tokens.mkString("/") @@ -39,7 +37,7 @@ class GoModHelper(config: Option[Config] = None, meta: Option[GoMod] = None) { // go.mod (module jorn.io/trial) and /foo.go (package foo) => jorn.io/trial>foo // go.mod (module jorn.io/trial) and /first/foo.go (package first) => jorn.io/trial/first // go.mod (module jorn.io/trial) and /first/foo.go (package bar) => jorn.io/trial/first - val remainingpath = compilationUnitFilePath.stripPrefix(config.get.inputPath) + val remainingpath = compilationUnitFilePath.stripPrefix(modulePath.get) val pathTokens = remainingpath.split(fileSeparateorPattern) // prefixing module name i.e. jorn.io/trial val tokens = meta.get.module.name +: pathTokens.dropRight(1).filterNot(x => x == null || x.trim.isEmpty) diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/passes/DownloadDependenciesPass.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/passes/DownloadDependenciesPass.scala index 00b5d2558202..d3234b5bc1db 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/passes/DownloadDependenciesPass.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/passes/DownloadDependenciesPass.scala @@ -27,13 +27,13 @@ class DownloadDependenciesPass(cpg: Cpg, parentGoMod: GoModHelper, goGlobal: GoG parentGoMod .getModMetaData() .foreach(mod => { - ExternalCommand.run("go mod init joern.io/temp", projDir) match { + ExternalCommand.run(Seq("go", "mod", "init", "joern.io/temp"), projDir).toTry match { case Success(_) => mod.dependencies .filter(dep => dep.beingUsed) .map(dependency => { - val cmd = s"go get ${dependency.dependencyStr()}" - val results = ExternalCommand.run(cmd, projDir) + val cmd = Seq("go", "get", dependency.dependencyStr()) + val results = ExternalCommand.run(cmd, projDir).toTry results match { case Success(_) => print(". ") @@ -83,10 +83,11 @@ class DownloadDependenciesPass(cpg: Cpg, parentGoMod: GoModHelper, goGlobal: GoG .withIgnoredFilesRegex(config.ignoredFilesRegex.toString()) .withIgnoredFiles(config.ignoredFiles.toList) val astGenResult = new AstGenRunner(depConfig, dependency.getIncludePackagesList()) - .execute(astLocation) - .asInstanceOf[GoAstGenRunnerResult] + .executeForGo(astLocation) + .headOption + .getOrElse(GoAstGenRunnerResult()) val goMod = new GoModHelper( - Some(depConfig), + Some(dependencyLocation), astGenResult.parsedModFile.flatMap(modFile => GoAstJsonParser.readModFile(Paths.get(modFile)).map(x => x)) ) DependencySrcProcessorPass(cpg, astGenResult.parsedFiles, depConfig, goMod, goGlobal, astLocation) diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/utils/AstGenRunner.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/utils/AstGenRunner.scala index 93f06aeea3f7..48fa174edbeb 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/utils/AstGenRunner.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/utils/AstGenRunner.scala @@ -10,12 +10,16 @@ import io.joern.x2cpg.utils.Environment.OperatingSystemType.OperatingSystemType import io.joern.x2cpg.utils.{Environment, ExternalCommand} import org.slf4j.LoggerFactory +import java.nio.file.Paths +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters.* import scala.util.matching.Regex import scala.util.{Failure, Success, Try} object AstGenRunner { private val logger = LoggerFactory.getLogger(getClass) case class GoAstGenRunnerResult( + modulePath: String = "", parsedModFile: Option[String] = None, parsedFiles: List[String] = List.empty, skippedFiles: List[String] = List.empty @@ -71,16 +75,18 @@ class AstGenRunner(config: Config, includeFileRegex: String = "") extends AstGen override def runAstGenNative(in: String, out: File, exclude: String, include: String)(implicit metaData: AstGenProgramMetaData ): Try[Seq[String]] = { - val excludeCommand = if (exclude.isEmpty) "" else s"-exclude \"$exclude\"" - val includeCommand = if (include.isEmpty) "" else s"-include-packages \"$include\"" - ExternalCommand.run(s"$astGenCommand $excludeCommand $includeCommand -out ${out.toString()} $in", ".") + val excludeCommand = if (exclude.isEmpty) Seq.empty else Seq("-exclude", exclude) + val includeCommand = if (include.isEmpty) Seq.empty else Seq("-include-packages", include) + ExternalCommand + .run((astGenCommand +: excludeCommand) ++ includeCommand ++ Seq("-out", out.toString(), in), ".") + .toTry } - override def execute(out: File): AstGenRunnerResult = { + def executeForGo(out: File): List[GoAstGenRunnerResult] = { implicit val metaData: AstGenProgramMetaData = config.astGenMetaData val in = File(config.inputPath) logger.info(s"Running goastgen in '$config.inputPath' ...") - runAstGenNative(config.inputPath, out, config.ignoredFilesRegex.toString(), includeFileRegex.toString()) match { + runAstGenNative(config.inputPath, out, config.ignoredFilesRegex.toString(), includeFileRegex) match { case Success(result) => val srcFiles = SourceFiles.determine( out.toString(), @@ -91,11 +97,108 @@ class AstGenRunner(config: Config, includeFileRegex: String = "") extends AstGen val parsedModFile = filterModFile(srcFiles, out) val parsed = filterFiles(srcFiles, out) val skipped = skippedFiles(in, result.toList) - GoAstGenRunnerResult(parsedModFile.headOption, parsed, skipped) + segregateByModule(config.inputPath, out.toString, parsedModFile, parsed, skipped) case Failure(f) => logger.error("\t- running astgen failed!", f) - GoAstGenRunnerResult() + List() } } + /** Segregate all parsed files including go.mod files under separate modules. This will also segregate modules defined + * inside another module + */ + private def segregateByModule( + inputPath: String, + outPath: String, + parsedModFiles: List[String], + parsedFiles: List[String], + skippedFiles: List[String] + ): List[GoAstGenRunnerResult] = { + val moduleMeta: ModuleMeta = + ModuleMeta(inputPath, outPath, None, ListBuffer[String](), ListBuffer[String](), ListBuffer[ModuleMeta]()) + if (parsedModFiles.nonEmpty) { + parsedModFiles + .sortBy(_.split(UtilityConstants.fileSeparateorPattern).length) + .foreach(modFile => { + moduleMeta.addModFile(modFile, inputPath, outPath) + }) + parsedFiles.foreach(moduleMeta.addParsedFile) + skippedFiles.foreach(moduleMeta.addSkippedFile) + moduleMeta.getOnlyChildren + } else { + parsedFiles.foreach(moduleMeta.addParsedFile) + skippedFiles.foreach(moduleMeta.addSkippedFile) + moduleMeta.getAllChildren + } + } + + private def getParentFolder(path: String): String = { + val parent = Paths.get(path).getParent + if (parent != null) parent.toString else "" + } + + case class ModuleMeta( + modulePath: String, + outputModulePath: String, + modFilePath: Option[String], + parsedFiles: ListBuffer[String], + skippedFiles: ListBuffer[String], + childModules: ListBuffer[ModuleMeta] + ) { + def addModFile(modFile: String, inputPath: String, outPath: String): Unit = { + childModules.collectFirst { + case childMod if modFile.startsWith(childMod.outputModulePath) => + childMod.addModFile(modFile, inputPath, outPath) + } match { + case None => + val outmodpath = getParentFolder(modFile) + childModules.addOne( + ModuleMeta( + outmodpath.replace(outPath, inputPath), + outmodpath, + Some(modFile), + ListBuffer[String](), + ListBuffer[String](), + ListBuffer[ModuleMeta]() + ) + ) + case _ => + } + } + + def addParsedFile(parsedFile: String): Unit = { + childModules.collectFirst { + case childMod if parsedFile.startsWith(childMod.outputModulePath) => + childMod.addParsedFile(parsedFile) + } match { + case None => parsedFiles.addOne(parsedFile) + case _ => + } + } + + def addSkippedFile(skippedFile: String): Unit = { + childModules.collectFirst { + case childMod if skippedFile.startsWith(childMod.outputModulePath) => + childMod.addSkippedFile(skippedFile) + } match { + case None => skippedFiles.addOne(skippedFile) + case _ => + } + } + + def getOnlyChildren: List[GoAstGenRunnerResult] = { + childModules.flatMap(_.getAllChildren).toList + } + + def getAllChildren: List[GoAstGenRunnerResult] = { + getOnlyChildren ++ List( + GoAstGenRunnerResult( + modulePath = modulePath, + parsedModFile = modFilePath, + parsedFiles = parsedFiles.toList, + skippedFiles = skippedFiles.toList + ) + ) + } + } } diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/ConditionalsDataflowTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/ConditionalsDataflowTests.scala index a0f4fbac6484..6dd2003ef65a 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/ConditionalsDataflowTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/ConditionalsDataflowTests.scala @@ -1,8 +1,8 @@ package io.joern.go2cpg.dataflow import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite class ConditionalsDataflowTests extends GoCodeToCpgSuite(withOssDataflow = true) { diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/LoopsDataflowTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/LoopsDataflowTests.scala index ce889f319522..f69b72680ea0 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/LoopsDataflowTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/LoopsDataflowTests.scala @@ -1,8 +1,8 @@ package io.joern.go2cpg.dataflow import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite class LoopsDataflowTests extends GoCodeToCpgSuite(withOssDataflow = true) { diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/SwitchDataflowTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/SwitchDataflowTests.scala index ad4d2a22445a..9eb77b9272ba 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/SwitchDataflowTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/SwitchDataflowTests.scala @@ -1,8 +1,8 @@ package io.joern.go2cpg.dataflow import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite class SwitchDataflowTests extends GoCodeToCpgSuite(withOssDataflow = true) { diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/TypeDeclConstructorDataflowTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/TypeDeclConstructorDataflowTests.scala index f2829f37f35e..5f79899ff2fe 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/TypeDeclConstructorDataflowTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/dataflow/TypeDeclConstructorDataflowTests.scala @@ -1,8 +1,8 @@ package io.joern.go2cpg.dataflow import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* class TypeDeclConstructorDataflowTests extends GoCodeToCpgSuite(withOssDataflow = true) { diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/io/GoSrc2CpgHTTPServerTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/io/GoSrc2CpgHTTPServerTests.scala new file mode 100644 index 000000000000..d9e1ac60a2ac --- /dev/null +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/io/GoSrc2CpgHTTPServerTests.scala @@ -0,0 +1,84 @@ +package io.joern.go2cpg.io + +import better.files.File +import io.joern.x2cpg.utils.server.FrontendHTTPClient +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader +import io.shiftleft.semanticcpg.language.* +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.collection.parallel.CollectionConverters.RangeIsParallelizable +import scala.util.Failure +import scala.util.Success + +class GoSrc2CpgHTTPServerTests extends AnyWordSpec with Matchers with BeforeAndAfterAll { + + private var port: Int = -1 + + private def newProjectUnderTest(index: Option[Int] = None): File = { + val dir = File.newTemporaryDirectory("gosrc2cpgTestsHttpTest") + val file = dir / "main.go" + file.createIfNotExists(createParents = true) + val indexStr = index.map(_.toString).getOrElse("") + file.writeText(s""" + |package main + |func main$indexStr() { + | print("Hello World!") + |} + |""".stripMargin) + file.deleteOnExit() + dir.deleteOnExit() + } + + override def beforeAll(): Unit = { + // Start server + port = io.joern.gosrc2cpg.Main.startup() + } + + override def afterAll(): Unit = { + // Stop server + io.joern.gosrc2cpg.Main.stop() + } + + "Using gosrc2cpg in server mode" should { + "build CPGs correctly (single test)" in { + val cpgOutFile = File.newTemporaryFile("gosrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l shouldBe List("""print("Hello World!")""") + } + } + + "build CPGs correctly (multi-threaded test)" in { + (0 until 10).par.foreach { index => + val cpgOutFile = File.newTemporaryFile("gosrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest(Some(index)) + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain(s"main$index") + cpg.call.code.l shouldBe List("""print("Hello World!")""") + } + } + } + } + +} diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/model/GoModTest.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/model/GoModTest.scala index 960916e97852..6691ab69e1db 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/model/GoModTest.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/model/GoModTest.scala @@ -17,13 +17,12 @@ class GoModTest extends AnyWordSpec with Matchers with BeforeAndAfterAll { namespace shouldBe "main" } "invalid compilation file unit with main pkg" in { - val config = Config() - config.inputPath = File.currentWorkingDirectory.toString() + val inputPath = File.currentWorkingDirectory.toString() val goMod = new GoModHelper( - Some(config), + Some(inputPath), Some( GoMod( - fileFullPath = File(config.inputPath) / "go.mod" pathAsString, + fileFullPath = File(inputPath) / "go.mod" pathAsString, module = GoModModule("joern.io/trial"), dependencies = List[GoModDependency]() ) @@ -37,128 +36,121 @@ class GoModTest extends AnyWordSpec with Matchers with BeforeAndAfterAll { } "with .mod file and main pkg 1 use case" in { - val config = Config() - config.inputPath = File.currentWorkingDirectory.toString() + val inputPath = File.currentWorkingDirectory.toString() val goMod = new GoModHelper( - Some(config), + Some(inputPath), Some( GoMod( - fileFullPath = File(config.inputPath) / "go.mod" pathAsString, + fileFullPath = File(inputPath) / "go.mod" pathAsString, module = GoModModule("joern.io/trial"), dependencies = List[GoModDependency]() ) ) ) val namespace = - goMod.getNameSpace(File(config.inputPath) / "first" / "second" / "test.go" pathAsString, "main") + goMod.getNameSpace(File(inputPath) / "first" / "second" / "test.go" pathAsString, "main") namespace shouldBe "first/second/main" } "with .mod file and main pkg 2 use case" in { - val config = Config() - config.inputPath = File.currentWorkingDirectory.toString() + JFile.separator + val inputPath = File.currentWorkingDirectory.toString() + JFile.separator val goMod = new GoModHelper( - Some(config), + Some(inputPath), Some( GoMod( - fileFullPath = File(config.inputPath) / "go.mod" pathAsString, + fileFullPath = File(inputPath) / "go.mod" pathAsString, module = GoModModule("joern.io/trial"), dependencies = List[GoModDependency]() ) ) ) val namespace = - goMod.getNameSpace(File(config.inputPath) / "first" / "second" / "test.go" pathAsString, "main") + goMod.getNameSpace(File(inputPath) / "first" / "second" / "test.go" pathAsString, "main") namespace shouldBe "first/second/main" } "with .mod file and main pkg 3 use case" in { - val config = Config() - config.inputPath = File.currentWorkingDirectory.toString() + val inputPath = File.currentWorkingDirectory.toString() + JFile.separator val goMod = new GoModHelper( - Some(config), + Some(inputPath), Some( GoMod( - fileFullPath = File(config.inputPath) / "go.mod" pathAsString, + fileFullPath = File(inputPath) / "go.mod" pathAsString, module = GoModModule("joern.io/trial"), dependencies = List[GoModDependency]() ) ) ) val namespace = - goMod.getNameSpace(File(config.inputPath) / "test.go" pathAsString, "main") + goMod.getNameSpace(File(inputPath) / "test.go" pathAsString, "main") namespace shouldBe "main" } "with .mod file and pkg other than main matching with folder" in { - val config = Config() - config.inputPath = File.currentWorkingDirectory.toString() + val inputPath = File.currentWorkingDirectory.toString() + JFile.separator val goMod = new GoModHelper( - Some(config), + Some(inputPath), Some( GoMod( - fileFullPath = File(config.inputPath) / "go.mod" pathAsString, + fileFullPath = File(inputPath) / "go.mod" pathAsString, module = GoModModule("joern.io/trial"), dependencies = List[GoModDependency]() ) ) ) val namespace = - goMod.getNameSpace(File(config.inputPath) / "test.go" pathAsString, "trial") + goMod.getNameSpace(File(inputPath) / "test.go" pathAsString, "trial") namespace shouldBe "joern.io/trial" } "with .mod file, pkg other than main, one level child folder, and package matching with last folder" in { - val config = Config() - config.inputPath = File.currentWorkingDirectory.toString() + val inputPath = File.currentWorkingDirectory.toString() + JFile.separator val goMod = new GoModHelper( - Some(config), + Some(inputPath), Some( GoMod( - fileFullPath = File(config.inputPath) / "go.mod" pathAsString, + fileFullPath = File(inputPath) / "go.mod" pathAsString, module = GoModModule("joern.io/trial"), dependencies = List[GoModDependency]() ) ) ) val namespace = - goMod.getNameSpace(File(config.inputPath) / "first" / "test.go" pathAsString, "first") + goMod.getNameSpace(File(inputPath) / "first" / "test.go" pathAsString, "first") namespace shouldBe "joern.io/trial/first" } "with .mod file and pkg other than main and not matching with folder" in { - val config = Config() - config.inputPath = File.currentWorkingDirectory.toString() + val inputPath = File.currentWorkingDirectory.toString() + JFile.separator val goMod = new GoModHelper( - Some(config), + Some(inputPath), Some( GoMod( - fileFullPath = File(config.inputPath) / "go.mod" pathAsString, + fileFullPath = File(inputPath) / "go.mod" pathAsString, module = GoModModule("joern.io/trial"), dependencies = List[GoModDependency]() ) ) ) val namespace = - goMod.getNameSpace(File(config.inputPath) / "test.go" pathAsString, "foo") + goMod.getNameSpace(File(inputPath) / "test.go" pathAsString, "foo") namespace shouldBe "joern.io/trial" } "with .mod file, pkg other than main, one level child folder, and package not matching with last folder" in { - val config = Config() - config.inputPath = File.currentWorkingDirectory.toString() + val inputPath = File.currentWorkingDirectory.toString() + JFile.separator val goMod = new GoModHelper( - Some(config), + Some(inputPath), Some( GoMod( - fileFullPath = File(config.inputPath) / "go.mod" pathAsString, + fileFullPath = File(inputPath) / "go.mod" pathAsString, module = GoModModule("joern.io/trial"), dependencies = List[GoModDependency]() ) ) ) val namespace = - goMod.getNameSpace(File(config.inputPath) / "first" / "test.go" pathAsString, "bar") + goMod.getNameSpace(File(inputPath) / "first" / "test.go" pathAsString, "bar") namespace shouldBe "joern.io/trial/first" } } diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ConditionalsTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ConditionalsTests.scala index 776c8c5a5142..36183ae1d152 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ConditionalsTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ConditionalsTests.scala @@ -4,8 +4,8 @@ import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.nodes.Call import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Operators} -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes import scala.collection.immutable.List diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/DeclarationsTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/DeclarationsTests.scala index 0d4deb9b345f..827b7812364e 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/DeclarationsTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/DeclarationsTests.scala @@ -1,8 +1,8 @@ package io.joern.go2cpg.passes.ast import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Operators} -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/DownloadDependencyTest.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/DownloadDependencyTest.scala index 4453314d54ba..4cc107c8ad32 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/DownloadDependencyTest.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/DownloadDependencyTest.scala @@ -225,9 +225,9 @@ class DownloadDependencyTest extends GoCodeToCpgSuite { "not create any entry in method full name to return type map" in { // This should only contain the `main` method return type mapping as main source code is not invoking any of the dependency method. goGlobal.nameSpaceMetaDataMap.size() shouldBe 1 - val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().toArray + val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().asScala.toArray metadata.methodMetaMap.size() shouldBe 1 - val List(mainfullname) = metadata.methodMetaMap.keys().asIterator().toList + val List(mainfullname) = metadata.methodMetaMap.keys().asIterator().asScala.toList mainfullname shouldBe "main" val Array(returnType) = metadata.methodMetaMap.values().toArray returnType shouldBe MethodCacheMetaData(Defines.voidTypeName, "main.main()") @@ -236,7 +236,7 @@ class DownloadDependencyTest extends GoCodeToCpgSuite { "not create any entry in struct member to type map" in { // This should be empty as neither main code has defined any struct type nor we are accessing the third party struct type. goGlobal.nameSpaceMetaDataMap.size() shouldBe 1 - val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().toArray + val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().asScala.toArray metadata.structTypeMembers.size() shouldBe 0 } } @@ -298,9 +298,9 @@ class DownloadDependencyTest extends GoCodeToCpgSuite { "not create any entry in method full name to return type map" ignore { // This should only contain the `main` method return type mapping as main source code is not invoking any of the dependency method. goGlobal.nameSpaceMetaDataMap.size() shouldBe 1 - val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().toArray + val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().asScala.toArray metadata.methodMetaMap.size() shouldBe 1 - val List(mainfullname) = metadata.methodMetaMap.keys().asIterator().toList + val List(mainfullname) = metadata.methodMetaMap.keys().asIterator().asScala.toList mainfullname shouldBe "main" val Array(returnType) = metadata.methodMetaMap.values().toArray returnType shouldBe MethodCacheMetaData(Defines.voidTypeName, "main.main()") @@ -310,7 +310,7 @@ class DownloadDependencyTest extends GoCodeToCpgSuite { "not create any entry in struct member to type map" ignore { // This should be empty as neither main code has defined any struct type nor we are accessing the third party struct type. goGlobal.nameSpaceMetaDataMap.size() shouldBe 1 - val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().toArray + val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().asScala.toArray metadata.structTypeMembers.size() shouldBe 0 } } @@ -397,9 +397,9 @@ class DownloadDependencyTest extends GoCodeToCpgSuite { // TODO: While doing the implementation we need update this test // Lambda expression return types are also getting recorded under this map goGlobal.nameSpaceMetaDataMap.size() shouldBe 1 - val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().toArray + val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().asScala.toArray metadata.methodMetaMap.size() shouldBe 1 - val List(mainfullname) = metadata.methodMetaMap.keys().asIterator().toList + val List(mainfullname) = metadata.methodMetaMap.keys().asIterator().asScala.toList mainfullname shouldBe "main" val Array(returnType) = metadata.methodMetaMap.values().toArray returnType shouldBe MethodCacheMetaData(Defines.voidTypeName, "main.main()") @@ -412,7 +412,7 @@ class DownloadDependencyTest extends GoCodeToCpgSuite { // 2. Struct Type is being passed as parameter or returned as value of method that is being used. // 3. A method of Struct Type being used. goGlobal.nameSpaceMetaDataMap.size() shouldBe 1 - val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().toArray + val Array(metadata) = goGlobal.nameSpaceMetaDataMap.values().iterator().asScala.toArray metadata.structTypeMembers.size() shouldBe 0 } } diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ExpressionsTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ExpressionsTests.scala index aff9c3dda63f..50f14e27db6b 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ExpressionsTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ExpressionsTests.scala @@ -3,8 +3,8 @@ package io.joern.go2cpg.passes.ast import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Operators} -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes class ExpressionsTests extends GoCodeToCpgSuite { diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/FileTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/FileTests.scala index f338ba6084aa..357fa250f93f 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/FileTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/FileTests.scala @@ -1,7 +1,7 @@ package io.joern.go2cpg.passes.ast import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import java.io.File diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ImportTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ImportTests.scala index dbacf6940a4a..571d6632a569 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ImportTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ImportTests.scala @@ -1,7 +1,7 @@ package io.joern.go2cpg.passes.ast import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ImportTests extends GoCodeToCpgSuite { diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MetaDataTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MetaDataTests.scala index 83662d57b493..98370d93b790 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MetaDataTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MetaDataTests.scala @@ -3,7 +3,7 @@ package io.joern.go2cpg.passes.ast import io.joern.x2cpg.layers.{Base, CallGraph, ControlFlow, TypeRelations} import io.shiftleft.codepropertygraph.generated.Languages import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MetaDataTests extends GoCodeToCpgSuite { diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MethodCallTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MethodCallTests.scala index ceb574e867ed..56972e8865a6 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MethodCallTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MethodCallTests.scala @@ -7,9 +7,8 @@ import io.shiftleft.codepropertygraph.generated.edges.Ref import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, nodes} import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.{jIteratortoTraversal, toNodeTraversal} - import java.io.File + class MethodCallTests extends GoCodeToCpgSuite(withOssDataflow = true) { "Simple method call use case" should { diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MultiModuleTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MultiModuleTests.scala new file mode 100644 index 000000000000..d3d117a00c4d --- /dev/null +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MultiModuleTests.scala @@ -0,0 +1,318 @@ +package io.joern.go2cpg.passes.ast + +import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite +import io.shiftleft.semanticcpg.language.* + +import java.io.File +import scala.collection.immutable.List + +class MultiModuleTests extends GoCodeToCpgSuite { + "Module defined under another directory" should { + val cpg = code( + """ + |module joern.io/sample + |go 1.18 + |""".stripMargin, + Seq("module1", "go.mod").mkString(File.separator) + ).moreCode( + """ + |package fpkg + |type Sample struct { + | Name string + |} + |func Woo(a int) int{ + | return 0 + |} + |""".stripMargin, + Seq("module1", "lib", "lib.go").mkString(File.separator) + ).moreCode( + """ + |package main + |import "joern.io/sample/lib" + |func main() { + | var a = fpkg.Woo(10) + | var b = fpkg.Sample{name: "Pandurang"} + | var c = b.Name + | var d fpkg.Sample + |} + |""".stripMargin, + Seq("module1", "main.go").mkString(File.separator) + ) + + "Check METHOD Node" in { + cpg.method("Woo").size shouldBe 1 + val List(x) = cpg.method("Woo").l + x.fullName shouldBe "joern.io/sample/lib.Woo" + x.signature shouldBe "joern.io/sample/lib.Woo(int)int" + } + + "Check CALL Node" in { + val List(x) = cpg.call("Woo").l + x.methodFullName shouldBe "joern.io/sample/lib.Woo" + x.typeFullName shouldBe "int" + } + + "Traversal from call to callee method node" in { + val List(x) = cpg.call("Woo").callee.l + x.fullName shouldBe "joern.io/sample/lib.Woo" + x.isExternal shouldBe false + } + + "Check TypeDecl Node" in { + val List(x) = cpg.typeDecl("Sample").l + x.fullName shouldBe "joern.io/sample/lib.Sample" + } + + "Check LOCAL Nodes" in { + val List(a, b, c, d) = cpg.local.l + a.typeFullName shouldBe "int" + b.typeFullName shouldBe "joern.io/sample/lib.Sample" + c.typeFullName shouldBe "string" + d.typeFullName shouldBe "joern.io/sample/lib.Sample" + } + } + + "Multiple modules defined under one directory" should { + val cpg = code( + """ + |module joern.io/module1 + |go 1.18 + |""".stripMargin, + Seq("module1", "go.mod").mkString(File.separator) + ).moreCode( + """ + |package pkg + |type ModoneSample struct { + | Name string + |} + |func ModoneWoo(a int) int{ + | return 0 + |} + |""".stripMargin, + Seq("module1", "pkg", "lib.go").mkString(File.separator) + ).moreCode( + """ + |package main + |import "joern.io/module1/pkg" + |func main() { + | var a = pkg.ModoneWoo(10) + | var b = pkg.ModoneSample{name: "Pandurang"} + | var c = b.Name + | var d pkg.ModoneSample + |} + |""".stripMargin, + Seq("module1", "main.go").mkString(File.separator) + ).moreCode( + """ + |module joern.io/module2 + |go 1.18 + |""".stripMargin, + Seq("module2", "go.mod").mkString(File.separator) + ).moreCode( + """ + |package pkg + |type ModtwoSample struct { + | Name string + |} + |func ModtwoWoo(a int) int{ + | return 0 + |} + |""".stripMargin, + Seq("module2", "pkg", "lib.go").mkString(File.separator) + ).moreCode( + """ + |package main + |import "joern.io/module2/pkg" + |func main() { + | var a = pkg.ModtwoWoo(10) + | var b = pkg.ModtwoSample{name: "Pandurang"} + | var c = b.Name + | var d pkg.ModtwoSample + |} + |""".stripMargin, + Seq("module2", "main.go").mkString(File.separator) + ) + "Check METHOD Node module 1" in { + cpg.method("ModoneWoo").size shouldBe 1 + val List(x) = cpg.method("ModoneWoo").l + x.fullName shouldBe "joern.io/module1/pkg.ModoneWoo" + x.signature shouldBe "joern.io/module1/pkg.ModoneWoo(int)int" + } + + "Check METHOD Node module 2" in { + cpg.method("ModtwoWoo").size shouldBe 1 + val List(x) = cpg.method("ModtwoWoo").l + x.fullName shouldBe "joern.io/module2/pkg.ModtwoWoo" + x.signature shouldBe "joern.io/module2/pkg.ModtwoWoo(int)int" + } + + "Check CALL Node module 1" in { + val List(x) = cpg.call("ModoneWoo").l + x.methodFullName shouldBe "joern.io/module1/pkg.ModoneWoo" + x.typeFullName shouldBe "int" + } + + "Check CALL Node module 2" in { + val List(x) = cpg.call("ModtwoWoo").l + x.methodFullName shouldBe "joern.io/module2/pkg.ModtwoWoo" + x.typeFullName shouldBe "int" + } + + "Traversal from call to callee method node module 1" in { + val List(x) = cpg.call("ModoneWoo").callee.l + x.fullName shouldBe "joern.io/module1/pkg.ModoneWoo" + x.isExternal shouldBe false + } + + "Traversal from call to callee method node module 2" in { + val List(x) = cpg.call("ModtwoWoo").callee.l + x.fullName shouldBe "joern.io/module2/pkg.ModtwoWoo" + x.isExternal shouldBe false + } + + "Check TypeDecl Node module 1" in { + val List(x) = cpg.typeDecl("ModoneSample").l + x.fullName shouldBe "joern.io/module1/pkg.ModoneSample" + } + + "Check TypeDecl Node module 2" in { + val List(x) = cpg.typeDecl("ModtwoSample").l + x.fullName shouldBe "joern.io/module2/pkg.ModtwoSample" + } + + "Check LOCAL Nodes Module 1 and 2" in { + val List(a, b, c, d, e, f, g, h) = cpg.local.l + a.typeFullName shouldBe "int" + b.typeFullName shouldBe "joern.io/module1/pkg.ModoneSample" + c.typeFullName shouldBe "string" + d.typeFullName shouldBe "joern.io/module1/pkg.ModoneSample" + + e.typeFullName shouldBe "int" + f.typeFullName shouldBe "joern.io/module2/pkg.ModtwoSample" + g.typeFullName shouldBe "string" + h.typeFullName shouldBe "joern.io/module2/pkg.ModtwoSample" + } + } + + "Multiple modules defined one inside another" should { + val cpg = code( + """ + |module joern.io/module1 + |go 1.18 + |""".stripMargin, + Seq("module1", "go.mod").mkString(File.separator) + ).moreCode( + """ + |package pkg + |type ModoneSample struct { + | Name string + |} + |func ModoneWoo(a int) int{ + | return 0 + |} + |""".stripMargin, + Seq("module1", "pkg", "lib.go").mkString(File.separator) + ).moreCode( + """ + |package main + |import "joern.io/module1/pkg" + |func main() { + | var a = pkg.ModoneWoo(10) + | var b = pkg.ModoneSample{name: "Pandurang"} + | var c = b.Name + | var d pkg.ModoneSample + |} + |""".stripMargin, + Seq("module1", "main.go").mkString(File.separator) + ).moreCode( + """ + |module joern.io/module2 + |go 1.18 + |""".stripMargin, + Seq("module1", "stage", "src", "module2", "go.mod").mkString(File.separator) + ).moreCode( + """ + |package pkg + |type ModtwoSample struct { + | Name string + |} + |func ModtwoWoo(a int) int{ + | return 0 + |} + |""".stripMargin, + Seq("module1", "stage", "src", "module2", "pkg", "lib.go").mkString(File.separator) + ).moreCode( + """ + |package main + |import "joern.io/module2/pkg" + |func main() { + | var a = pkg.ModtwoWoo(10) + | var b = pkg.ModtwoSample{name: "Pandurang"} + | var c = b.Name + | var d pkg.ModtwoSample + |} + |""".stripMargin, + Seq("module1", "stage", "src", "module2", "main.go").mkString(File.separator) + ) + "Check METHOD Node module 1" in { + cpg.method("ModoneWoo").size shouldBe 1 + val List(x) = cpg.method("ModoneWoo").l + x.fullName shouldBe "joern.io/module1/pkg.ModoneWoo" + x.signature shouldBe "joern.io/module1/pkg.ModoneWoo(int)int" + } + + "Check METHOD Node module 2" in { + cpg.method("ModtwoWoo").size shouldBe 1 + val List(x) = cpg.method("ModtwoWoo").l + x.fullName shouldBe "joern.io/module2/pkg.ModtwoWoo" + x.signature shouldBe "joern.io/module2/pkg.ModtwoWoo(int)int" + } + + "Check CALL Node module 1" in { + val List(x) = cpg.call("ModoneWoo").l + x.methodFullName shouldBe "joern.io/module1/pkg.ModoneWoo" + x.typeFullName shouldBe "int" + } + + "Check CALL Node module 2" in { + val List(x) = cpg.call("ModtwoWoo").l + x.methodFullName shouldBe "joern.io/module2/pkg.ModtwoWoo" + x.typeFullName shouldBe "int" + } + + "Traversal from call to callee method node module 1" in { + val List(x) = cpg.call("ModoneWoo").callee.l + x.fullName shouldBe "joern.io/module1/pkg.ModoneWoo" + x.isExternal shouldBe false + } + + "Traversal from call to callee method node module 2" in { + val List(x) = cpg.call("ModtwoWoo").callee.l + x.fullName shouldBe "joern.io/module2/pkg.ModtwoWoo" + x.isExternal shouldBe false + } + + "Check TypeDecl Node module 1" in { + val List(x) = cpg.typeDecl("ModoneSample").l + x.fullName shouldBe "joern.io/module1/pkg.ModoneSample" + } + + "Check TypeDecl Node module 2" in { + val List(x) = cpg.typeDecl("ModtwoSample").l + x.fullName shouldBe "joern.io/module2/pkg.ModtwoSample" + } + + "Check LOCAL Nodes Module 1 and 2" in { + val List(a, b, c, d, e, f, g, h) = cpg.local.l + a.typeFullName shouldBe "int" + b.typeFullName shouldBe "joern.io/module2/pkg.ModtwoSample" + c.typeFullName shouldBe "string" + d.typeFullName shouldBe "joern.io/module2/pkg.ModtwoSample" + + e.typeFullName shouldBe "int" + f.typeFullName shouldBe "joern.io/module1/pkg.ModoneSample" + g.typeFullName shouldBe "string" + h.typeFullName shouldBe "joern.io/module1/pkg.ModoneSample" + } + } +} diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/NamespaceBlockTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/NamespaceBlockTests.scala index b9013d929b36..e860d36a142f 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/NamespaceBlockTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/NamespaceBlockTests.scala @@ -1,7 +1,7 @@ package io.joern.go2cpg.passes.ast import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/OperatorsTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/OperatorsTests.scala index 21238c36cb69..5b56a56d070c 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/OperatorsTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/OperatorsTests.scala @@ -3,7 +3,7 @@ package io.joern.go2cpg.passes.ast import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class OperatorsTests extends GoCodeToCpgSuite { diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/TypeDeclMembersAndMemberMethodsTest.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/TypeDeclMembersAndMemberMethodsTest.scala index 878af00ed211..34d2bc145379 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/TypeDeclMembersAndMemberMethodsTest.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/TypeDeclMembersAndMemberMethodsTest.scala @@ -4,8 +4,8 @@ import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.nodes.Call import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Operators} -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes import scala.collection.immutable.List import io.joern.gosrc2cpg.astcreation.Defines diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/testfixtures/GoCodeToCpgSuite.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/testfixtures/GoCodeToCpgSuite.scala index 2e7601ac3979..e5be2e191a4b 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/testfixtures/GoCodeToCpgSuite.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/testfixtures/GoCodeToCpgSuite.scala @@ -1,7 +1,8 @@ package io.joern.go2cpg.testfixtures import better.files.File -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic +import io.joern.dataflowengineoss.DefaultSemantics +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.joern.dataflowengineoss.testfixtures.{SemanticCpgTestFixture, SemanticTestCpg} import io.joern.gosrc2cpg.datastructures.GoGlobal import io.joern.gosrc2cpg.model.GoModHelper @@ -49,11 +50,11 @@ class DefaultTestCpgWithGo(val fileSuffix: String) extends DefaultTestCpg with S class GoCodeToCpgSuite( fileSuffix: String = ".go", withOssDataflow: Boolean = false, - extraFlows: List[FlowSemantic] = List.empty + semantics: Semantics = DefaultSemantics() ) extends Code2CpgFixture(() => - new DefaultTestCpgWithGo(fileSuffix).withOssDataflow(withOssDataflow).withExtraFlows(extraFlows) + new DefaultTestCpgWithGo(fileSuffix).withOssDataflow(withOssDataflow).withSemantics(semantics) ) - with SemanticCpgTestFixture(extraFlows) + with SemanticCpgTestFixture(semantics) with Inside { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/javasrc2cpg/build.sbt b/joern-cli/frontends/javasrc2cpg/build.sbt index 130b79c3c0cc..f8c33d985049 100644 --- a/joern-cli/frontends/javasrc2cpg/build.sbt +++ b/joern-cli/frontends/javasrc2cpg/build.sbt @@ -10,7 +10,8 @@ libraryDependencies ++= Seq( "org.projectlombok" % "lombok" % Versions.lombok, "org.scala-lang.modules" %% "scala-parallel-collections" % Versions.scalaParallel, "org.scala-lang.modules" %% "scala-parser-combinators" % Versions.scalaParserCombinators, - "net.lingala.zip4j" % "zip4j" % Versions.zip4j + "net.lingala.zip4j" % "zip4j" % Versions.zip4j, + "org.ow2.asm" % "asm" % Versions.asm, ) enablePlugins(JavaAppPackaging, LauncherJarPlugin) diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala index 2b78bf8f68f9..0c19cd972077 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala @@ -1,6 +1,6 @@ package io.joern.javasrc2cpg -import io.joern.javasrc2cpg.passes.{AstCreationPass, TypeInferencePass} +import io.joern.javasrc2cpg.passes.{AstCreationPass, OuterClassRefPass, TypeInferencePass} import io.joern.x2cpg.X2Cpg.withNewEmptyCpg import io.joern.x2cpg.passes.frontend.{JavaConfigFileCreationPass, MetaDataPass, TypeNodePass} import io.joern.x2cpg.X2CpgFrontend @@ -15,8 +15,6 @@ import scala.util.matching.Regex class JavaSrc2Cpg extends X2CpgFrontend[Config] { import JavaSrc2Cpg._ - private val logger = LoggerFactory.getLogger(this.getClass) - override def createCpg(config: Config): Try[Cpg] = { withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => new MetaDataPass(cpg, language, config.inputPath).createAndApply() @@ -24,6 +22,7 @@ class JavaSrc2Cpg extends X2CpgFrontend[Config] { astCreationPass.createAndApply() astCreationPass.sourceParser.cleanupDelombokOutput() astCreationPass.clearJavaParserCaches() + new OuterClassRefPass(cpg).createAndApply() JavaConfigFileCreationPass(cpg).createAndApply() if (!config.skipTypeInfPass) { TypeNodePass.withRegisteredTypes(astCreationPass.global.usedTypes.keys().asScala.toList, cpg).createAndApply() @@ -35,28 +34,24 @@ class JavaSrc2Cpg extends X2CpgFrontend[Config] { } object JavaSrc2Cpg { - val language: String = Languages.JAVASRC - private val logger = LoggerFactory.getLogger(this.getClass) - + val language: String = Languages.JAVASRC val sourceFileExtensions: Set[String] = Set(".java") val DefaultIgnoredFilesRegex: List[Regex] = List(".git", ".mvn", "test").flatMap { directory => List(s"(^|\\\\)$directory($$|\\\\)".r.unanchored, s"(^|/)$directory($$|/)".r.unanchored) } - val DefaultConfig: Config = Config().withDefaultIgnoredFilesRegex(DefaultIgnoredFilesRegex) def apply(): JavaSrc2Cpg = new JavaSrc2Cpg() def showEnv(): Unit = { - val value = - JavaSrcEnvVar.values.foreach { envVar => - val currentValue = Option(System.getenv(envVar.name)).getOrElse("") - println(s"${envVar.name}:") - println(s" Description : ${envVar.description}") - println(s" Current value: $currentValue") - } + JavaSrcEnvVar.values.foreach { envVar => + val currentValue = sys.env.getOrElse(envVar.name, "") + println(s"${envVar.name}:") + println(s"\tDescription : ${envVar.description}") + println(s"\tCurrent value: $currentValue") + } } enum JavaSrcEnvVar(val name: String, val description: String) { diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/Main.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/Main.scala index e89fc8d7e883..fc14f8e0fa80 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/Main.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/Main.scala @@ -2,11 +2,16 @@ package io.joern.javasrc2cpg import io.joern.javasrc2cpg.Frontend.* import io.joern.javasrc2cpg.jpastprinter.JavaParserAstPrinter +import io.joern.x2cpg.X2CpgConfig +import io.joern.x2cpg.X2CpgMain import io.joern.x2cpg.frontendspecific.javasrc2cpg -import io.joern.x2cpg.{X2CpgConfig, X2CpgMain} -import io.joern.x2cpg.passes.frontend.{TypeRecoveryParserConfig, XTypeRecovery, XTypeRecoveryConfig} +import io.joern.x2cpg.passes.frontend.TypeRecoveryParserConfig +import io.joern.x2cpg.passes.frontend.XTypeRecoveryConfig +import io.joern.x2cpg.utils.server.FrontendHTTPServer import scopt.OParser +import java.util.concurrent.ExecutorService + /** Command line configuration parameters */ final case class Config( @@ -21,7 +26,8 @@ final case class Config( skipTypeInfPass: Boolean = false, dumpJavaparserAsts: Boolean = false, cacheJdkTypeSolver: Boolean = false, - keepTypeArguments: Boolean = false + keepTypeArguments: Boolean = false, + disableTypeFallback: Boolean = false ) extends X2CpgConfig[Config] with TypeRecoveryParserConfig[Config] { def withInferenceJarPaths(paths: Set[String]): Config = { @@ -71,6 +77,10 @@ final case class Config( def withKeepTypeArguments(value: Boolean): Config = { copy(keepTypeArguments = value).withInheritedFields(this) } + + def withDisableTypeFallback(value: Boolean): Config = { + copy(disableTypeFallback = value).withInheritedFields(this) + } } private object Frontend { @@ -128,12 +138,19 @@ private object Frontend { opt[Unit]("keep-type-arguments") .hidden() .action((_, c) => c.withKeepTypeArguments(true)) - .text("Type full names of variables keep their type arguments.") + .text("Type full names of variables keep their type arguments."), + opt[Unit]("disable-type-fallback") + .action((_, c) => c.withDisableTypeFallback(true)) + .text( + "Disables fallback to wildcard imports, unsound type inferences and the Any type (except where no better information is available)." + ) ) } } -object Main extends X2CpgMain(cmdLineParser, new JavaSrc2Cpg()) { +object Main extends X2CpgMain(cmdLineParser, new JavaSrc2Cpg()) with FrontendHTTPServer[Config, JavaSrc2Cpg] { + + override protected def newDefaultConfig(): Config = Config() override def main(args: Array[String]): Unit = { // TODO: This is a hack to allow users to use the "--show-env" option without having @@ -146,14 +163,14 @@ object Main extends X2CpgMain(cmdLineParser, new JavaSrc2Cpg()) { } def run(config: Config, javasrc2Cpg: JavaSrc2Cpg): Unit = { - if (config.showEnv) { - JavaSrc2Cpg.showEnv() - } else if (config.dumpJavaparserAsts) { - JavaParserAstPrinter.printJpAsts(config) - } else { - javasrc2Cpg.run(config) + config match { + case c if c.serverMode => startup() + case c if c.showEnv => JavaSrc2Cpg.showEnv() + case c if c.dumpJavaparserAsts => JavaParserAstPrinter.printJpAsts(c) + case _ => javasrc2Cpg.run(config) } } def getCmdLineParser = cmdLineParser + } diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/AstCreator.scala index 3d661070e83c..4230465d48ed 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/AstCreator.scala @@ -1,5 +1,6 @@ package io.joern.javasrc2cpg.astcreation +import com.github.javaparser.ast.`type`.Type import com.github.javaparser.ast.expr.{ AnnotationExpr, BooleanLiteralExpr, @@ -26,7 +27,7 @@ import com.github.javaparser.resolution.declarations.{ import com.github.javaparser.resolution.types.ResolvedType import com.github.javaparser.resolution.types.parametrization.ResolvedTypeParametersMap import com.github.javaparser.symbolsolver.JavaSymbolSolver -import io.joern.javasrc2cpg.astcreation.declarations.AstForDeclarationsCreator +import io.joern.javasrc2cpg.astcreation.declarations.{AstForDeclarationsCreator, BinarySignatureCalculator} import io.joern.javasrc2cpg.astcreation.expressions.AstForExpressionsCreator import io.joern.javasrc2cpg.astcreation.statements.AstForStatementsCreator import io.joern.javasrc2cpg.scope.Scope @@ -43,15 +44,17 @@ import io.joern.javasrc2cpg.util.{ BindingTable, BindingTableAdapterForJavaparser, MultiBindingTableAdapterForJavaparser, - NameConstants + NameConstants, + TemporaryNameProvider, + Util } import io.joern.x2cpg.datastructures.Global import io.joern.x2cpg.utils.OffsetUtils -import io.joern.x2cpg.{Ast, AstCreatorBase, AstNodeBuilder, ValidationMode} +import io.joern.x2cpg.{Ast, AstCreatorBase, AstNodeBuilder, Defines, ValidationMode} import io.shiftleft.codepropertygraph.generated.NodeTypes import io.shiftleft.codepropertygraph.generated.nodes.{NewClosureBinding, NewFile, NewImport, NewNamespaceBlock} import org.slf4j.LoggerFactory -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable @@ -90,7 +93,7 @@ class AstCreator( val symbolSolver: JavaSymbolSolver, protected val keepTypeArguments: Boolean, val loggedExceptionCounts: scala.collection.concurrent.Map[Class[?], Int] -)(implicit val withSchemaValidation: ValidationMode) +)(implicit val withSchemaValidation: ValidationMode, val disableTypeFallback: Boolean) extends AstCreatorBase(filename) with AstNodeBuilder[Node, AstCreator] with AstForDeclarationsCreator @@ -104,6 +107,9 @@ class AstCreator( private[astcreation] val typeInfoCalc: TypeInfoCalculator = TypeInfoCalculator(global, symbolSolver, keepTypeArguments) private[astcreation] val bindingTableCache = mutable.HashMap.empty[String, BindingTable] + private[astcreation] val binarySignatureCalculator: BinarySignatureCalculator = new BinarySignatureCalculator(scope) + + private[astcreation] val tempNameProvider: TemporaryNameProvider = new TemporaryNameProvider /** Entry point of AST creation. Translates a compilation unit created by JavaParser into a DiffGraph containing the * corresponding CPG AST. @@ -130,15 +136,38 @@ class AstCreator( } } + private[astcreation] def getTypeFullName(expectedType: ExpectedType): Option[String] = { + Option.unless(disableTypeFallback)(expectedType.fullName).flatten + } + + private[astcreation] def defaultTypeFallback(typ: Type): String = { + defaultTypeFallback(code(typ)) + } + + private[astcreation] def defaultTypeFallback(typ: String): String = { + if (disableTypeFallback) { + s"${Defines.UnresolvedNamespace}.${Util.stripGenericTypes(typ)}" + } else + TypeConstants.Any + } + + private[astcreation] def defaultTypeFallback(): String = { + TypeConstants.Any + } + + private[astcreation] def isResolvedTypeFullName(typeFullName: String): Boolean = { + typeFullName != TypeConstants.Any && !typeFullName.startsWith(Defines.UnresolvedNamespace) + } + /** Custom printer that omits comments. To be used by [[code]] */ private val codePrinterOptions = new DefaultPrinterConfiguration() .removeOption(new DefaultConfigurationOption(ConfigOption.PRINT_COMMENTS)) .removeOption(new DefaultConfigurationOption(ConfigOption.PRINT_JAVADOC)) - protected def line(node: Node): Option[Int] = node.getBegin.map(x => x.line).toScala - protected def column(node: Node): Option[Int] = node.getBegin.map(x => x.column).toScala - protected def lineEnd(node: Node): Option[Int] = node.getEnd.map(x => x.line).toScala - protected def columnEnd(node: Node): Option[Int] = node.getEnd.map(x => x.column).toScala + protected def line(node: Node): Option[Int] = node.getBegin.map(_.line).toScala + protected def column(node: Node): Option[Int] = node.getBegin.map(_.column).toScala + protected def lineEnd(node: Node): Option[Int] = node.getEnd.map(_.line).toScala + protected def columnEnd(node: Node): Option[Int] = node.getEnd.map(_.column).toScala protected def code(node: Node): String = node.toString(codePrinterOptions) private val lineOffsetTable = OffsetUtils.getLineOffsetTable(fileContent) @@ -348,8 +377,10 @@ class AstCreator( case _ => None } - def argumentTypesForMethodLike(maybeResolvedMethodLike: Try[ResolvedMethodLikeDeclaration]): Option[List[String]] = { - maybeResolvedMethodLike.toOption + def argumentTypesForMethodLike( + maybeResolvedMethodLike: Option[ResolvedMethodLikeDeclaration] + ): Option[List[String]] = { + maybeResolvedMethodLike .flatMap(calcParameterTypes(_, ResolvedTypeParametersMap.empty())) } diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForMethodsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForMethodsCreator.scala index 9352e254b8f8..596327cf7978 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForMethodsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForMethodsCreator.scala @@ -4,6 +4,7 @@ import io.joern.x2cpg.utils.AstPropertiesUtil.* import com.github.javaparser.ast.NodeList import com.github.javaparser.ast.body.{ CallableDeclaration, + CompactConstructorDeclaration, ConstructorDeclaration, FieldDeclaration, MethodDeclaration, @@ -11,7 +12,11 @@ import com.github.javaparser.ast.body.{ VariableDeclarator } import com.github.javaparser.ast.stmt.{BlockStmt, ExplicitConstructorInvocationStmt} -import com.github.javaparser.resolution.declarations.{ResolvedMethodDeclaration, ResolvedMethodLikeDeclaration} +import com.github.javaparser.resolution.declarations.{ + ResolvedMethodDeclaration, + ResolvedMethodLikeDeclaration, + ResolvedParameterDeclaration +} import com.github.javaparser.resolution.types.ResolvedType import com.github.javaparser.resolution.types.parametrization.ResolvedTypeParametersMap import io.joern.javasrc2cpg.astcreation.{AstCreator, ExpectedType} @@ -21,25 +26,32 @@ import io.joern.x2cpg.utils.NodeBuilders import io.joern.x2cpg.utils.NodeBuilders.* import io.joern.x2cpg.{Ast, Defines} import io.shiftleft.codepropertygraph.generated.nodes.{ + AstNodeNew, NewBlock, + NewCall, + NewFieldIdentifier, NewIdentifier, NewMethod, NewMethodParameterIn, NewMethodReturn, NewModifier } -import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, ModifierTypes} +import io.shiftleft.codepropertygraph.generated.{ + DispatchTypes, + EdgeTypes, + EvaluationStrategies, + ModifierTypes, + NodeTypes, + Operators, + nodes +} import io.joern.javasrc2cpg.scope.JavaScopeElement.fullName import scala.jdk.CollectionConverters.* import scala.jdk.OptionConverters.RichOptional import scala.util.{Failure, Success, Try} -import io.shiftleft.codepropertygraph.generated.nodes.AstNodeNew -import io.shiftleft.codepropertygraph.generated.nodes.NewCall -import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.codepropertygraph.generated.EdgeTypes import com.github.javaparser.ast.Node +import com.github.javaparser.ast.`type`.ClassOrInterfaceType import com.github.javaparser.symbolsolver.javaparsermodel.declarations.JavaParserParameterDeclaration import io.joern.javasrc2cpg.astcreation.declarations.AstForMethodsCreator.PartialConstructorDeclaration import io.joern.javasrc2cpg.util.{NameConstants, Util} @@ -49,8 +61,9 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => val methodNode = createPartialMethod(methodDeclaration) val typeParameters = getIdentifiersForTypeParameters(methodDeclaration) + methodDeclaration.getType - val maybeResolved = tryWithSafeStackOverflow(methodDeclaration.resolve()) + val maybeResolved = tryWithSafeStackOverflow(methodDeclaration.resolve()).toOption val expectedReturnType = tryWithSafeStackOverflow( symbolSolver.toResolvedType(methodDeclaration.getType, classOf[ResolvedType]) ).toOption @@ -59,11 +72,9 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => val returnTypeFullName = expectedReturnType .flatMap(typeInfoCalc.fullName) .orElse(simpleMethodReturnType.flatMap(scope.lookupType(_))) - .orElse( - tryWithSafeStackOverflow(methodDeclaration.getType.asClassOrInterfaceType).toOption.flatMap(t => - scope.lookupType(t.getNameAsString) - ) - ) + .orElse(tryWithSafeStackOverflow(methodDeclaration.getType).toOption.collect { case t: ClassOrInterfaceType => + scope.lookupType(t.getNameAsString) + }.flatten) .orElse(typeParameters.find(typeParam => simpleMethodReturnType.contains(typeParam.name)).map(_.typeFullName)) scope.pushMethodScope( @@ -71,9 +82,12 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => ExpectedType(returnTypeFullName, expectedReturnType), methodDeclaration.isStatic() ) - typeParameters.foreach { typeParameter => scope.addTopLevelType(typeParameter.name, typeParameter.typeFullName) } + typeParameters.foreach { typeParameter => scope.addTypeParameter(typeParameter.name, typeParameter.typeFullName) } + + val genericSignature = binarySignatureCalculator.methodBinarySignature(methodDeclaration) + methodNode.genericSignature(genericSignature) - val parameterAsts = astsForParameterList(methodDeclaration.getParameters) + val parameterAsts = astsForParameterList(methodDeclaration.getParameters.asScala.toList) val parameterTypes = argumentTypesForMethodLike(maybeResolved) val signature = composeSignature(returnTypeFullName, parameterTypes, parameterAsts.size) val namespaceName = scope.enclosingTypeDecl.fullName.getOrElse(Defines.UnresolvedNamespace) @@ -85,20 +99,28 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => val thisNode = Option.when(!methodDeclaration.isStatic) { val typeFullName = scope.enclosingTypeDecl.fullName - thisNodeForMethod(typeFullName, line(methodDeclaration)) + thisNodeForMethod(typeFullName, line(methodDeclaration), column(methodDeclaration)) } val thisAst = thisNode.map(Ast(_)).toList thisNode.foreach { node => - scope.enclosingMethod.get.addParameter(node) + scope.enclosingMethod.get.addParameter(node, scope.enclosingTypeDecl.get.typeDecl.genericSignature) } - val bodyAst = methodDeclaration.getBody.toScala.map(astForBlockStatement(_)).getOrElse(Ast(NewBlock())) + val bodyAst = methodDeclaration.getBody.toScala + .map(astForBlockStatement(_, includeTemporaryLocals = true)) + .getOrElse(Ast(NewBlock())) val (lineNr, columnNr) = tryWithSafeStackOverflow(methodDeclaration.getType) match { case Success(typ) => (line(typ), column(typ)) case Failure(_) => (line(methodDeclaration), column(methodDeclaration)) } - val methodReturn = newMethodReturnNode(returnTypeFullName.getOrElse(TypeConstants.Any), None, lineNr, columnNr) + val methodReturn = + newMethodReturnNode( + returnTypeFullName.getOrElse(defaultTypeFallback(methodDeclaration.getType)), + None, + lineNr, + columnNr + ) val annotationAsts = methodDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toSeq @@ -109,6 +131,61 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => methodAstWithAnnotations(methodNode, thisAst ++ parameterAsts, bodyAst, methodReturn, modifiers, annotationAsts) } + private[declarations] def astForRecordParameterAccessor( + parameter: Parameter, + recordTypeFullName: String, + parameterName: String, + parameterTypeFullName: String + ): Ast = { + val signature = + if (isResolvedTypeFullName(parameterTypeFullName)) + composeSignature(Option(parameterTypeFullName), Option(Nil), 0) + else + composeSignature(None, Option(Nil), 0) + + val methodFullName = composeMethodFullName(recordTypeFullName, parameterName, signature) + + val methodReturn = + newMethodReturnNode(parameterTypeFullName, line = line(parameter), column = column(parameter)) + + val genericSignature = binarySignatureCalculator.recordParameterAccessorBinarySignature(parameter) + val methodRoot = methodNode( + parameter, + parameterName, + s"public ${code(parameter.getType)} ${parameterName}()", + methodFullName, + Option(signature), + filename, + Option(NodeTypes.TYPE_DECL), + Option(recordTypeFullName), + genericSignature = Option(genericSignature) + ) + + val modifier = newModifierNode(ModifierTypes.PUBLIC) + + val thisParameter = thisNodeForMethod(Option(recordTypeFullName), line(parameter), column(parameter)) + + val thisIdentifier = identifierNode(parameter, thisParameter.name, thisParameter.code, recordTypeFullName) + val thisIdentifierAst = Ast(thisIdentifier).withRefEdge(thisIdentifier, thisParameter) + val fieldIdentifier = fieldIdentifierNode(parameter, parameterName, parameterName) + + val fieldAccessNode = newOperatorCallNode( + Operators.fieldAccess, + s"${thisIdentifier.code}.${fieldIdentifier.code}", + Option(parameterTypeFullName), + line(parameter), + column(parameter) + ) + val fieldAccessCall = callAst(fieldAccessNode, thisIdentifierAst :: Ast(fieldIdentifier) :: Nil) + + val returnStmt = returnNode(parameter, s"return ${fieldAccessNode.code}") + val returnAst = Ast(returnStmt).withChild(fieldAccessCall) + + val methodBodyAst = blockAst(blockNode(parameter), returnAst :: Nil) + + methodAst(methodRoot, Ast(thisParameter) :: Nil, methodBodyAst, methodReturn, modifier :: Nil) + } + private def abstractModifierForCallable( callableDeclaration: CallableDeclaration[?], isInterfaceMethod: Boolean @@ -123,13 +200,23 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => } } - private def modifiersForMethod(methodDeclaration: CallableDeclaration[?]): List[NewModifier] = { + private def modifiersForMethod( + methodDeclaration: CallableDeclaration[?] | CompactConstructorDeclaration + ): List[NewModifier] = { val isInterfaceMethod = scope.enclosingTypeDecl.isInterface - val abstractModifier = abstractModifierForCallable(methodDeclaration, isInterfaceMethod) + val abstractModifier = Option + .when(methodDeclaration.isCallableDeclaration)( + abstractModifierForCallable(methodDeclaration.asCallableDeclaration(), isInterfaceMethod) + ) + .flatten - val staticVirtualModifierType = if (methodDeclaration.isStatic) ModifierTypes.STATIC else ModifierTypes.VIRTUAL - val staticVirtualModifier = Some(newModifierNode(staticVirtualModifierType)) + // TODO: The opposite of static is not virtual + val staticVirtualModifierType = + if (methodDeclaration.isCallableDeclaration && methodDeclaration.asCallableDeclaration().isStatic) + ModifierTypes.STATIC + else ModifierTypes.VIRTUAL + val staticVirtualModifier = Some(newModifierNode(staticVirtualModifierType)) val accessModifierType = if (methodDeclaration.isPublic) { Some(ModifierTypes.PUBLIC) @@ -148,7 +235,7 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => List(accessModifier, abstractModifier, staticVirtualModifier).flatten } - private def getIdentifiersForTypeParameters(methodDeclaration: MethodDeclaration): List[NewIdentifier] = { + private def getIdentifiersForTypeParameters(methodDeclaration: CallableDeclaration[?]): List[NewIdentifier] = { methodDeclaration.getTypeParameters.asScala.map { typeParameter => val name = typeParameter.getNameAsString val typeFullName = tryWithSafeStackOverflow(typeParameter.getTypeBound.asScala.headOption).toOption.flatten @@ -174,32 +261,51 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => fieldDeclaration.getVariables.asScala.filter(_.getInitializer.isPresent).toList.flatMap { variableDeclaration => scope.pushFieldDeclScope(fieldDeclaration.isStatic, variableDeclaration.getNameAsString) val assignmentAsts = astsForVariableDeclarator(variableDeclaration, fieldDeclaration) + val patternAsts = scope.enclosingMethod.get.getUnaddedPatternVariableAstsAndMarkAdded() scope.popFieldDeclScope() - assignmentAsts + patternAsts ++ assignmentAsts } } } def astForDefaultConstructor(originNode: Node, instanceFieldDeclarations: List[FieldDeclaration]): Ast = { + val parameters = scope.enclosingTypeDecl.get.recordParameters + val genericSignature = binarySignatureCalculator.defaultConstructorSignature(parameters) + val constructorNode = NewMethod() + .name(io.joern.x2cpg.Defines.ConstructorMethodName) + .filename(filename) + .isExternal(false) + .genericSignature(genericSignature) + .lineNumber(line(originNode)) + .columnNumber(column(originNode)) + scope.pushMethodScope(constructorNode, ExpectedType.Void, isStatic = false) + + val parameterAsts = parameters.zipWithIndex.map { case (param, idx) => + astForParameter(param, idx + 1) + } + val parameterTypes = parameterAsts.map(_.rootType.getOrElse(defaultTypeFallback())) + val resolvedParameterTypes = Option.when(parameterTypes.forall(isResolvedTypeFullName))(parameterTypes) + val typeFullName = scope.enclosingTypeDecl.fullName - val signature = s"${TypeConstants.Void}()" + val signature = composeSignature(Option(TypeConstants.Void), resolvedParameterTypes, parameterAsts.size) val fullName = composeMethodFullName( typeFullName.getOrElse(Defines.UnresolvedNamespace), Defines.ConstructorMethodName, signature ) - val constructorNode = NewMethod() - .name(io.joern.x2cpg.Defines.ConstructorMethodName) - .fullName(fullName) - .signature(signature) - .filename(filename) - .isExternal(false) - scope.pushMethodScope(constructorNode, ExpectedType.Void, isStatic = false) + constructorNode.fullName(fullName) + constructorNode.signature(signature) - val thisNode = thisNodeForMethod(typeFullName, lineNumber = None) - scope.enclosingMethod.foreach(_.addParameter(thisNode)) - val bodyStatementAsts = astsForFieldInitializers(instanceFieldDeclarations) + val thisNode = thisNodeForMethod(typeFullName, lineNumber = None, columnNumber = None) + scope.enclosingMethod.foreach(_.addParameter(thisNode, scope.enclosingTypeDecl.get.typeDecl.genericSignature)) + val recordParameterAssignments = parameterAsts + .flatMap(_.nodes) + .collect { case param: nodes.NewMethodParameterIn => param } + .map(astForEponymousFieldAssignment(thisNode, _)) + val bodyStatementAsts = + astsForFieldInitializers(instanceFieldDeclarations) ++ recordParameterAssignments + val temporaryLocalAsts = scope.enclosingMethod.map(_.getTemporaryLocals).getOrElse(Nil).map(Ast(_)) val returnNode = newMethodReturnNode(TypeConstants.Void, line = None, column = None) @@ -209,8 +315,8 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => originNode, constructorNode, thisNode, - explicitParameterAsts = Nil, - bodyStatementAsts = bodyStatementAsts, + explicitParameterAsts = parameterAsts, + bodyStatementAsts = temporaryLocalAsts ++ bodyStatementAsts, methodReturn = returnNode, annotationAsts = Nil, modifiers = modifiers, @@ -222,6 +328,52 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => constructorAst } + private def astForEponymousFieldAssignment( + thisParam: NewMethodParameterIn, + recordParameter: NewMethodParameterIn + ): Ast = { + val thisIdentifier = NewIdentifier() + .name(thisParam.name) + .code(thisParam.name) + .typeFullName(thisParam.typeFullName) + .lineNumber(recordParameter.lineNumber) + .columnNumber(recordParameter.columnNumber) + .dynamicTypeHintFullName(thisParam.dynamicTypeHintFullName) + val thisIdentifierAst = Ast(thisIdentifier).withRefEdge(thisIdentifier, thisParam) + + val fieldIdentifier = NewFieldIdentifier() + .canonicalName(recordParameter.name) + .code(recordParameter.name) + + val fieldAccessNode = newOperatorCallNode( + Operators.fieldAccess, + s"${thisIdentifier.code}.${fieldIdentifier.code}", + Option(recordParameter.typeFullName), + recordParameter.lineNumber, + recordParameter.columnNumber + ) + val fieldAccessAst = callAst(fieldAccessNode, thisIdentifierAst :: Ast(fieldIdentifier) :: Nil) + + val recordParamIdentifier = NewIdentifier() + .name(recordParameter.name) + .code(recordParameter.name) + .typeFullName(recordParameter.typeFullName) + .lineNumber(recordParameter.lineNumber) + .columnNumber(recordParameter.columnNumber) + .dynamicTypeHintFullName(recordParameter.dynamicTypeHintFullName) + val recordParamIdentifierAst = Ast(recordParamIdentifier).withRefEdge(recordParamIdentifier, recordParameter) + + val assignmentNode = newOperatorCallNode( + Operators.assignment, + s"${fieldAccessNode.code} = ${recordParamIdentifier.code}", + Option(recordParameter.typeFullName), + recordParameter.lineNumber, + recordParameter.columnNumber + ) + + callAst(assignmentNode, fieldAccessAst :: recordParamIdentifierAst :: Nil) + } + private def astForParameter(parameter: Parameter, childNum: Int): Ast = { val maybeArraySuffix = if (parameter.isVarArgs) "[]" else "" val rawParameterTypeName = @@ -249,7 +401,8 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => val annotationAsts = parameter.getAnnotations.asScala.map(astForAnnotationExpr) val ast = Ast(parameterNode) - scope.enclosingMethod.get.addParameter(parameterNode) + scope.enclosingMethod.get + .addParameter(parameterNode, binarySignatureCalculator.variableBinarySignature(parameter.getType)) ast.withChildren(annotationAsts) } @@ -258,23 +411,29 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => methodLike: ResolvedMethodLikeDeclaration, typeParamValues: ResolvedTypeParametersMap ): Option[List[String]] = { - val parameterTypes = - Range(0, methodLike.getNumberOfParams) - .flatMap { index => - Try(methodLike.getParam(index)).toOption - } - .map { param => - tryWithSafeStackOverflow(param.getType).toOption - .flatMap(paramType => typeInfoCalc.fullName(paramType, typeParamValues)) - // In a scenario where we have an import of an external type e.g. `import foo.bar.Baz` and - // this parameter's type is e.g. `Baz`, the lookup will fail. However, if we lookup - // for `Baz` instead (i.e. without type arguments), then the lookup will succeed. - .orElse( - Try( - param.asInstanceOf[JavaParserParameterDeclaration].getWrappedNode.getType.asClassOrInterfaceType - ).toOption.flatMap(t => scope.lookupType(t.getNameAsString)) - ) - } + val parameters = + Range(0, methodLike.getNumberOfParams).flatMap { index => + Try(methodLike.getParam(index)).toOption + }.toList + + calcParameterTypes(parameters, typeParamValues) + } + + def calcParameterTypes( + parameters: List[ResolvedParameterDeclaration], + typeParamValues: ResolvedTypeParametersMap + ): Option[List[String]] = { + val parameterTypes = parameters.map { param => + tryWithSafeStackOverflow(param.getType).toOption + .flatMap(paramType => typeInfoCalc.fullName(paramType, typeParamValues)) + // In a scenario where we have an import of an external type e.g. `import foo.bar.Baz` and + // this parameter's type is e.g. `Baz`, the lookup will fail. However, if we lookup + // for `Baz` instead (i.e. without type arguments), then the lookup will succeed. + .orElse( + Try(param.asInstanceOf[JavaParserParameterDeclaration].getWrappedNode.getType.asClassOrInterfaceType).toOption + .flatMap(t => scope.lookupType(t.getNameAsString)) + ) + } toOptionList(parameterTypes) } @@ -302,27 +461,47 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => composeSignature(maybeReturnType, maybeParameterTypes, method.getNumberOfParams) } - private def astsForParameterList(parameters: NodeList[Parameter]): Seq[Ast] = { - parameters.asScala.toList.zipWithIndex.map { case (param, idx) => + + private def astsForParameterList(parameters: List[Parameter]): Seq[Ast] = { + parameters.zipWithIndex.map { case (param, idx) => astForParameter(param, idx + 1) } } private def partialConstructorAsts( - constructorDeclarations: List[ConstructorDeclaration], + constructorDeclarations: List[ConstructorDeclaration | CompactConstructorDeclaration], instanceFieldDeclarations: List[FieldDeclaration] ): List[PartialConstructorDeclaration] = { constructorDeclarations.map { constructorDeclaration => + val maybeResolved = Option + .when(constructorDeclaration.isConstructorDeclaration)( + tryWithSafeStackOverflow(constructorDeclaration.resolve()).toOption + ) + .flatten val constructorNode = createPartialMethod(constructorDeclaration) .name(io.joern.x2cpg.Defines.ConstructorMethodName) scope.pushMethodScope(constructorNode, ExpectedType.Void, isStatic = false) - val maybeResolved = tryWithSafeStackOverflow(constructorDeclaration.resolve()) + constructorDeclaration match { + case regularConstructor: ConstructorDeclaration => + val typeParameters = getIdentifiersForTypeParameters(regularConstructor) + typeParameters.foreach(typeParam => scope.addTypeParameter(typeParam.name, typeParam.typeFullName)) + case _ => // Compact constructor cannot have type parameters + } - val parameterAsts = astsForParameterList(constructorDeclaration.getParameters).toList - val paramTypes = argumentTypesForMethodLike(maybeResolved) - val signature = composeSignature(Some(TypeConstants.Void), paramTypes, parameterAsts.size) - val typeFullName = scope.enclosingTypeDecl.fullName + val parameters = constructorDeclaration match { + case regularConstructor: ConstructorDeclaration => regularConstructor.getParameters.asScala.toList + case compactConstructor: CompactConstructorDeclaration => scope.enclosingTypeDecl.get.recordParameters + } + val parameterAsts = astsForParameterList(parameters).toList + val paramTypes = constructorDeclaration match { + case constructor: ConstructorDeclaration => argumentTypesForMethodLike(maybeResolved) + case constructor: CompactConstructorDeclaration => + val resolvedParams = parameters.flatMap(param => tryWithSafeStackOverflow(param.resolve()).toOption).toList + calcParameterTypes(resolvedParams, ResolvedTypeParametersMap.empty()) + } + val signature = composeSignature(Some(TypeConstants.Void), paramTypes, parameterAsts.size) + val typeFullName = scope.enclosingTypeDecl.fullName val fullName = composeMethodFullName( typeFullName.getOrElse(Defines.UnresolvedNamespace), @@ -334,30 +513,44 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => .fullName(fullName) .signature(signature) - parameterAsts.foreach { ast => + parameterAsts.zip(parameters).foreach { (ast, parameterNode) => ast.root match { - case Some(parameter: NewMethodParameterIn) => scope.enclosingMethod.get.addParameter(parameter) - case _ => // This should never happen + case Some(parameter: NewMethodParameterIn) => + val genericType = binarySignatureCalculator.variableBinarySignature(parameterNode.getType) + scope.enclosingMethod.get.addParameter(parameter, genericType) + case _ => // This should never happen } } - val thisNode = thisNodeForMethod(typeFullName, line(constructorDeclaration)) - scope.enclosingMethod.get.addParameter(thisNode) + val thisNode = thisNodeForMethod(typeFullName, line(constructorDeclaration), column(constructorDeclaration)) + scope.enclosingMethod.get.addParameter(thisNode, scope.enclosingTypeDecl.get.typeDecl.genericSignature) scope.pushBlockScope() + val recordParameterAssignments = constructorDeclaration match { + case constructor: CompactConstructorDeclaration => + parameterAsts + .flatMap(_.nodes) + .collect { case param: nodes.NewMethodParameterIn => param } + .map(astForEponymousFieldAssignment(thisNode, _)) + case _ => Nil + } + val bodyStatements = constructorDeclaration.getBody.getStatements.asScala.toList + val statementsAsts = bodyStatements.flatMap(astsForStatement) val bodyContainsThis = bodyStatements.headOption .collect { case consInvocation: ExplicitConstructorInvocationStmt => consInvocation.isThis } .getOrElse(false) - val fieldAssignments = + val fieldAssignmentsAndTempLocals = if (bodyContainsThis) Nil else - astsForFieldInitializers(instanceFieldDeclarations) + scope.enclosingMethod.get.getTemporaryLocals.map(Ast(_)) ++ astsForFieldInitializers( + instanceFieldDeclarations + ) - // The this(...) call must always be the first statement in the body, but adding the fieldAssignments + // The this(...) call must always be the first statement in the body, but adding the fieldAssignmentsAndTempLocals // before the body asts here is safe, since the list will be empty if the body does start with this() - val bodyAsts = fieldAssignments ++ bodyStatements.flatMap(astsForStatement) + val bodyAsts = recordParameterAssignments ++ fieldAssignmentsAndTempLocals ++ statementsAsts scope.popBlockScope() val methodReturn = constructorReturnNode(constructorDeclaration) @@ -454,7 +647,7 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => } def astsForConstructors( - constructorDeclarations: List[ConstructorDeclaration], + constructorDeclarations: List[ConstructorDeclaration | CompactConstructorDeclaration], instanceFieldDeclarations: List[FieldDeclaration] ): Map[Node, Ast] = { val partialConstructors = partialConstructorAsts(constructorDeclarations, instanceFieldDeclarations) @@ -463,9 +656,11 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => }.toMap } - private def constructorReturnNode(constructorDeclaration: ConstructorDeclaration): NewMethodReturn = { - val line = constructorDeclaration.getEnd.map(x => x.line).toScala - val column = constructorDeclaration.getEnd.map(x => x.column).toScala + private def constructorReturnNode( + constructorDeclaration: ConstructorDeclaration | CompactConstructorDeclaration + ): NewMethodReturn = { + val line = constructorDeclaration.getEnd.map(_.line).toScala + val column = constructorDeclaration.getEnd.map(_.column).toScala newMethodReturnNode(TypeConstants.Void, None, line, column) } @@ -490,22 +685,45 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator => /** Constructor and Method declarations share a lot of fields, so this method adds the fields they have in common. * `fullName` and `signature` are omitted */ - private def createPartialMethod(declaration: CallableDeclaration[?]): NewMethod = { - val code = declaration.getDeclarationAsString.trim + private def createPartialMethod(declaration: CallableDeclaration[?] | CompactConstructorDeclaration): NewMethod = { + val methodCode = declaration match { + case callableDeclaration: CallableDeclaration[?] => callableDeclaration.getDeclarationAsString.trim + case compactConstructor: CompactConstructorDeclaration => code(compactConstructor) + } val columnNumber = declaration.getBegin.map(x => Integer.valueOf(x.column)).toScala val endLine = declaration.getEnd.map(x => Integer.valueOf(x.line)).toScala val endColumn = declaration.getEnd.map(x => Integer.valueOf(x.column)).toScala val placeholderFullName = "" - methodNode(declaration, declaration.getNameAsString(), code, placeholderFullName, None, filename) + + val genericSignature = declaration match { + case callableDeclaration: CallableDeclaration[_] => + binarySignatureCalculator.methodBinarySignature(callableDeclaration) + case compactConstructor: CompactConstructorDeclaration => + binarySignatureCalculator.defaultConstructorSignature(scope.enclosingTypeDecl.get.recordParameters) + } + methodNode( + declaration, + declaration.getNameAsString(), + methodCode, + placeholderFullName, + None, + filename, + genericSignature = Option(genericSignature) + ) } - def thisNodeForMethod(maybeTypeFullName: Option[String], lineNumber: Option[Int]): NewMethodParameterIn = { - val typeFullName = typeInfoCalc.registerType(maybeTypeFullName.getOrElse(TypeConstants.Any)) + def thisNodeForMethod( + maybeTypeFullName: Option[String], + lineNumber: Option[Int], + columnNumber: Option[Int] + ): NewMethodParameterIn = { + val typeFullName = typeInfoCalc.registerType(maybeTypeFullName.getOrElse(defaultTypeFallback())) NodeBuilders.newThisParameterNode( typeFullName = typeFullName, dynamicTypeHintFullName = maybeTypeFullName.toSeq, - line = lineNumber + line = lineNumber, + column = columnNumber ) } } diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala index c036d6e016f4..1567c51ef8b4 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala @@ -2,13 +2,17 @@ package io.joern.javasrc2cpg.astcreation.declarations import com.github.javaparser.ast.body.{ AnnotationDeclaration, + AnnotationMemberDeclaration, BodyDeclaration, ClassOrInterfaceDeclaration, + CompactConstructorDeclaration, ConstructorDeclaration, EnumConstantDeclaration, + EnumDeclaration, FieldDeclaration, InitializerDeclaration, MethodDeclaration, + RecordDeclaration, TypeDeclaration, VariableDeclarator } @@ -37,7 +41,7 @@ import com.github.javaparser.resolution.declarations.{ ResolvedReferenceTypeDeclaration, ResolvedTypeParameterDeclaration } -import io.joern.javasrc2cpg.astcreation.AstCreator +import io.joern.javasrc2cpg.astcreation.{AstCreator, ExpectedType} import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants import io.joern.javasrc2cpg.util.{BindingTable, BindingTableEntry, NameConstants, Util} import io.joern.x2cpg.utils.NodeBuilders.* @@ -57,9 +61,6 @@ import scala.jdk.CollectionConverters.* import scala.util.{Success, Try} import com.github.javaparser.ast.expr.ObjectCreationExpr import com.github.javaparser.ast.stmt.LocalClassDeclarationStmt -import com.github.javaparser.ast.body.AnnotationMemberDeclaration -import com.github.javaparser.ast.body.CompactConstructorDeclaration -import com.github.javaparser.ast.body.EnumDeclaration import io.joern.javasrc2cpg.scope.Scope.ScopeVariable import com.github.javaparser.ast.Node import com.github.javaparser.resolution.types.ResolvedReferenceType @@ -90,6 +91,10 @@ object AstForTypeDeclsCreator { private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => private val logger = LoggerFactory.getLogger(this.getClass) + private def outerClassGenericSignature: Option[String] = { + scope.enclosingTypeDecl.map(decl => binarySignatureCalculator.variableBinarySignature(decl.typeDecl.name)) + } + def astForAnonymousClassDecl( expr: ObjectCreationExpr, body: List[BodyDeclaration[?]], @@ -99,6 +104,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => ): Ast = { val (astParentType, astParentFullName) = getAstParentInfo() + val genericSignature = binarySignatureCalculator.variableBinarySignature(expr.getType) val typeDeclRoot = typeDeclNode( expr, @@ -108,16 +114,23 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => expr.toString(), astParentType, astParentFullName, - baseTypeFullName.getOrElse(TypeConstants.Object) :: Nil + baseTypeFullName.getOrElse(TypeConstants.Object) :: Nil, + genericSignature = Option(genericSignature) ) - typeFullName.foreach(scope.addInnerType(typeName, _)) + typeFullName.foreach(typeFullName => scope.addInnerType(typeName, typeFullName, typeFullName)) val declaredMethodNames = body.collect { case methodDeclaration: MethodDeclaration => methodDeclaration.getNameAsString }.toSet - scope.pushTypeDeclScope(typeDeclRoot, scope.isEnclosingScopeStatic, declaredMethodNames) + scope.pushTypeDeclScope( + typeDeclRoot, + scope.isEnclosingScopeStatic, + outerClassGenericSignature, + declaredMethodNames, + Nil + ) val memberAsts = astsForTypeDeclMembers(expr, body, isInterface = false, typeFullName) val localDecls = scope.localDeclsInScope @@ -157,11 +170,16 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => val name = localClassDecl.getClassDeclaration.getNameAsString val enclosingMethodPrefix = scope.enclosingMethod.getMethodFullName.takeWhile(_ != ':') val fullName = s"$enclosingMethodPrefix.$name" - scope.addInnerType(name, fullName) - astForTypeDeclaration(localClassDecl.getClassDeclaration, fullNameOverride = Some(fullName)) + scope.addInnerType(name, fullName, fullName) + astForTypeDeclaration(localClassDecl.getClassDeclaration, fullNameOverride = Some(fullName), isLocalClass = true) } - def astForTypeDeclaration(typeDeclaration: TypeDeclaration[?], fullNameOverride: Option[String] = None): Ast = { + def astForTypeDeclaration( + typeDeclaration: TypeDeclaration[?], + fullNameOverride: Option[String] = None, + isLocalClass: Boolean = false + ): Ast = { + val isInterface = typeDeclaration match { case classDeclaration: ClassOrInterfaceDeclaration => classDeclaration.isInterface case _ => false @@ -172,16 +190,41 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => val typeDeclRoot = createTypeDeclNode(typeDeclaration, astParentType, astParentFullName, isInterface, fullNameOverride) + // If this is a nested type (which must be true if an enclosing decl exists at this point), then the internal name + // of the class, e.g. Foo$Bar must be added to the scope to make lookups for type Bar possible. + scope.enclosingTypeDecl.foreach { _ => + if (!isLocalClass) + scope.addInnerType(typeDeclaration.getNameAsString, typeDeclRoot.fullName, typeDeclRoot.name) + } + val declaredMethodNames = typeDeclaration.getMethods.asScala.map(_.getNameAsString).toSet - scope.pushTypeDeclScope(typeDeclRoot, typeDeclaration.isStatic, declaredMethodNames) + + val recordParameters = typeDeclaration match { + case recordDeclaration: RecordDeclaration => recordDeclaration.getParameters.asScala.toList + case _ => Nil + } + + scope.pushTypeDeclScope( + typeDeclRoot, + typeDeclaration.isStatic, + outerClassGenericSignature, + declaredMethodNames, + recordParameters + ) addTypeDeclTypeParamsToScope(typeDeclaration) + val recordParameterAsts = typeDeclaration match { + case recordDeclaration: RecordDeclaration => astsForRecordParameters(recordDeclaration, typeDeclRoot.fullName) + case _ => Nil + } + val annotationAsts = typeDeclaration.getAnnotations.asScala.map(astForAnnotationExpr) val modifiers = modifiersForTypeDecl(typeDeclaration, isInterface) val enumEntries = typeDeclaration match { case enumDeclaration: EnumDeclaration => enumDeclaration.getEntries.asScala.toList case _ => Nil } + val memberAsts = astsForTypeDeclMembers( typeDeclaration, @@ -194,6 +237,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => val lambdaMethods = scope.lambdaMethodsInScope val typeDeclAst = Ast(typeDeclRoot) + .withChildren(recordParameterAsts) .withChildren(memberAsts) .withChildren(annotationAsts) .withChildren(localDecls) @@ -228,6 +272,39 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => typeDeclAst } + private def astsForRecordParameters(recordDeclaration: RecordDeclaration, recordTypeFullName: String): List[Ast] = { + val explicitMethodNames = recordDeclaration.getMethods.asScala.map(_.getNameAsString).toSet + + recordDeclaration.getParameters.asScala.toList.flatMap { parameter => + val parameterName = parameter.getNameAsString + val parameterTypeFullName = tryWithSafeStackOverflow { + val typ = parameter.getType + scope + .lookupScopeType(typ.asString()) + .map(_.typeFullName) + .orElse(typeInfoCalc.fullName(typ)) + .getOrElse(defaultTypeFallback(typ)) + }.toOption.getOrElse(defaultTypeFallback()) + + val genericSignature = binarySignatureCalculator.variableBinarySignature(parameter.getType) + val parameterMember = memberNode( + parameter, + parameterName, + code(parameter), + parameterTypeFullName, + genericSignature = Option(genericSignature) + ) + val privateModifier = newModifierNode(ModifierTypes.PRIVATE) + val memberAst = Ast(parameterMember).withChild(Ast(privateModifier)) + + val accessorMethodAst = Option.unless(explicitMethodNames.contains(parameterName))( + astForRecordParameterAccessor(parameter, recordTypeFullName, parameterName, parameterTypeFullName) + ) + + memberAst :: accessorMethodAst.toList + } + } + private def bindingTypeForReferenceType(typ: ResolvedReferenceType): Option[JavaparserBindingDeclType] = { typ.getTypeDeclaration.toScala.map(typeDecl => scope.getDeclBinding(typeDecl.getName) match { @@ -290,7 +367,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => members.collect { case typeDeclaration: TypeDeclaration[_] => val (name, fullName) = getTypeDeclNameAndFullName(typeDeclaration, fullNameOverride) - scope.addInnerType(name, fullName) + scope.addInnerType(name, fullName, fullName) } val fields = members.collect { case fieldDeclaration: FieldDeclaration => fieldDeclaration } @@ -343,19 +420,35 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => } val constructorAstMap = astsForConstructors( - members.collect { case constructor: ConstructorDeclaration => - constructor + members.collect { + case constructor: ConstructorDeclaration => constructor + case constructor: CompactConstructorDeclaration => constructor }, instanceFields ) val membersAsts = membersAstPairs.flatMap { - case (constructor: ConstructorDeclaration, _) => - constructorAstMap.get(constructor) - case (_, asts) => asts + case (constructor: ConstructorDeclaration, _) => constructorAstMap.get(constructor) + case (constructor: CompactConstructorDeclaration, _) => constructorAstMap.get(constructor) + case (_, asts) => asts } - val defaultConstructorAst = Option.when(!(isInterface || members.exists(_.isInstanceOf[ConstructorDeclaration]))) { + val hasCanonicalConstructor = scope.enclosingTypeDecl.get.recordParameters match { + case Nil => members.exists(member => member.isConstructorDeclaration || member.isCompactConstructorDeclaration) + + case recordParameters => + members.collect { + case compactConstructorDeclaration: CompactConstructorDeclaration => compactConstructorDeclaration + + case constructorDeclaration: ConstructorDeclaration + if constructorDeclaration.getParameters.asScala + .map(_.getType) + .toList + .equals(recordParameters.map(_.getType)) => + constructorDeclaration + }.nonEmpty + } + val defaultConstructorAst = Option.when(!(isInterface || hasCanonicalConstructor)) { astForDefaultConstructor(originNode, instanceFields) } @@ -399,8 +492,10 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => receiverAst.root.foreach(receiver => diffGraph.addEdge(initRoot, receiver, EdgeTypes.RECEIVER)) val capturesAsts = - usedCaptures.filterNot(outerClassAst.isDefined && _.name == NameConstants.OuterClass).zipWithIndex.map { - (usedCapture, index) => + usedCaptures + .filterNot(outerClassAst.isDefined && _.name == NameConstants.OuterClass) + .zipWithIndex + .map { (usedCapture, index) => val identifier = NewIdentifier() .name(usedCapture.name) .code(usedCapture.name) @@ -408,10 +503,10 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => .lineNumber(initRoot.lineNumber) .columnNumber(initRoot.columnNumber) - diffGraph.addEdge(identifier, usedCapture.node, EdgeTypes.REF) + val refsTo = Option.when(usedCapture.name != NameConstants.OuterClass)(usedCapture.node) - Ast(identifier) - } + Ast(identifier).withRefEdges(identifier, refsTo.toList) + } (receiverAst :: args ++ outerClassAst.toList ++ capturesAsts) .map { argAst => @@ -437,13 +532,20 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => private def membersForCapturedVariables(originNode: Node, captures: List[ScopeVariable]): List[Ast] = { captures.map { variable => - val node = memberNode(originNode, variable.name, variable.name, variable.typeFullName) + val node = memberNode( + originNode, + variable.name, + variable.name, + variable.typeFullName, + genericSignature = Option(variable.genericSignature) + ) Ast(node) } } private def getStaticFieldInitializers(staticFields: List[FieldDeclaration]): List[Ast] = { - staticFields.flatMap { field => + scope.pushMethodScope(NewMethod(), ExpectedType.empty, isStatic = true) + val fieldsAsts = staticFields.flatMap { field => field.getVariables.asScala.toList.flatMap { variable => scope.pushFieldDeclScope(isStatic = true, name = variable.getNameAsString) val assignment = astsForVariableDeclarator(variable, field) @@ -451,11 +553,16 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => assignment } } + val methodScope = scope.popMethodScope() + methodScope.getTemporaryLocals.map(Ast(_)) ++ methodScope + .getUnaddedPatternVariableAstsAndMarkAdded() ++ fieldsAsts } private[declarations] def astForAnnotationExpr(annotationExpr: AnnotationExpr): Ast = { - val fallbackType = s"${Defines.UnresolvedNamespace}.${annotationExpr.getNameAsString}" - val fullName = expressionReturnTypeFullName(annotationExpr).getOrElse(fallbackType) + val fullName = scope + .lookupType(annotationExpr.getNameAsString) + .orElse(tryWithSafeStackOverflow(annotationExpr.resolve()).toOption.flatMap(typeInfoCalc.fullName)) + .getOrElse(defaultTypeFallback(annotationExpr.getNameAsString)) typeInfoCalc.registerType(fullName) val code = annotationExpr.toString val name = annotationExpr.getName.getIdentifier @@ -580,9 +687,11 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => val name = v.getName.toString // Use type name without generics stripped in code val variableTypeString = tryWithSafeStackOverflow(v.getTypeAsString).getOrElse("") - val node = memberNode(v, name, s"$variableTypeString $name", typeFullName) - val memberAst = Ast(node) - val annotationAsts = annotations.asScala.map(astForAnnotationExpr) + val genericSignature = binarySignatureCalculator.variableBinarySignature(v.getType) + val node = + memberNode(v, name, s"$variableTypeString $name", typeFullName, genericSignature = Option(genericSignature)) + val memberAst = Ast(node) + val annotationAsts = annotations.asScala.map(astForAnnotationExpr) val fieldDeclModifiers = modifiersForFieldDeclaration(fieldDeclaration) @@ -634,6 +743,10 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => Seq() } maybeJavaObjectType ++ inheritsFromTypeNames + } else if (typ.isEnumDeclaration) { + TypeConstants.Enum :: Nil + } else if (typ.isRecordDeclaration) { + TypeConstants.Record :: Nil } else { List.empty[String] } @@ -642,7 +755,18 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => val code = codeForTypeDecl(typ, isInterface) - typeDeclNode(typ, name, fullName, filename, code, astParentType, astParentFullName, baseTypeFullNames) + val genericSignature = binarySignatureCalculator.typeDeclBinarySignature(typ) + typeDeclNode( + typ, + name, + fullName, + filename, + code, + astParentType, + astParentFullName, + baseTypeFullNames, + genericSignature = Option(genericSignature) + ) } private def codeForTypeDecl(typ: TypeDeclaration[?], isInterface: Boolean): String = { @@ -682,15 +806,21 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => } private def addTypeDeclTypeParamsToScope(typ: TypeDeclaration[?]): Unit = { - tryWithSafeStackOverflow(typ.resolve()).map(_.getTypeParameters.asScala) match { - case Success(resolvedTypeParams) => - resolvedTypeParams - .map(identifierForResolvedTypeParameter) - .foreach { typeParamIdentifier => - scope.addTopLevelType(typeParamIdentifier.name, typeParamIdentifier.typeFullName) - } - - case _ => // Nothing to do here + val typeParameters = typ match { + case classOrInterfaceDeclaration: ClassOrInterfaceDeclaration => + classOrInterfaceDeclaration.getTypeParameters.asScala + case recordDeclaration: RecordDeclaration => recordDeclaration.getTypeParameters.asScala + case _ => Nil + } + + typeParameters.foreach { case typeParam => + // TODO: Use typeParam.getTypeBound list to calculate this instead to allow better fallback. + val typeFullName = tryWithSafeStackOverflow(typeParam.resolve().asTypeParameter().getUpperBound).toOption + .flatMap(typeInfoCalc.fullName) + .getOrElse(TypeConstants.Object) + typeInfoCalc.registerType(typeFullName) + + scope.addTypeParameter(typeParam.getNameAsString, typeFullName) } } @@ -699,7 +829,14 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => val typeFullName = tryWithSafeStackOverflow(entry.resolve().getType).toOption.flatMap(typeInfoCalc.fullName) - val entryNode = memberNode(entry, entry.getNameAsString, entry.toString, typeFullName.getOrElse("ANY")) + val genericSignature = binarySignatureCalculator.enumEntryBinarySignature(entry) + val entryNode = memberNode( + entry, + entry.getNameAsString, + entry.toString, + typeFullName.getOrElse("ANY"), + genericSignature = Some(genericSignature) + ) val name = s"${typeFullName.getOrElse(Defines.UnresolvedNamespace)}.${Defines.ConstructorMethodName}" diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/BinarySignatureCalculator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/BinarySignatureCalculator.scala new file mode 100644 index 000000000000..34629db57ca6 --- /dev/null +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/BinarySignatureCalculator.scala @@ -0,0 +1,327 @@ +package io.joern.javasrc2cpg.astcreation.declarations + +import com.github.javaparser.ast.`type`.{ + ArrayType, + ClassOrInterfaceType, + PrimitiveType, + Type, + TypeParameter, + UnknownType, + VarType, + VoidType, + WildcardType +} +import com.github.javaparser.ast.body.{ + AnnotationDeclaration, + CallableDeclaration, + ClassOrInterfaceDeclaration, + ConstructorDeclaration, + EnumConstantDeclaration, + EnumDeclaration, + MethodDeclaration, + Parameter, + RecordDeclaration, + TypeDeclaration +} +import com.github.javaparser.ast.expr.{LambdaExpr, TypePatternExpr} +import com.github.javaparser.printer.configuration.DefaultPrinterConfiguration.ConfigOption +import com.github.javaparser.printer.configuration.{DefaultConfigurationOption, DefaultPrinterConfiguration} +import io.joern.javasrc2cpg.astcreation.declarations.BinarySignatureCalculator.{ + BaseTypeMap, + javaEnumName, + javaObjectName, + javaRecordName, + unspecifiedType +} +import io.joern.javasrc2cpg.scope.Scope +import io.joern.javasrc2cpg.scope.Scope.ScopeTypeParam +import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator +import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants +import io.joern.javasrc2cpg.util.Util +import org.objectweb.asm.signature.SignatureWriter +import org.slf4j.LoggerFactory + +import scala.jdk.CollectionConverters.* +import scala.jdk.OptionConverters.RichOptional + +object BinarySignatureCalculator { + private val javaObjectName = "Object" + private val javaEnumName = "Enum" + private val javaRecordName = "Record" + private val unspecifiedType = "__unspecified_type" + + // From https://docs.oracle.com/javase/specs/jvms/se23/html/jvms-4.html#jvms-4.3 + val BaseTypeMap: Map[String, Char] = Seq( + TypeConstants.Byte -> 'B', + TypeConstants.Char -> 'C', + TypeConstants.Double -> 'D', + TypeConstants.Float -> 'F', + TypeConstants.Int -> 'I', + TypeConstants.Long -> 'J', + TypeConstants.Short -> 'S', + TypeConstants.Boolean -> 'Z', + TypeConstants.Void -> 'V' + ).toMap +} + +class BinarySignatureCalculator(scope: Scope) { + + private val logger = LoggerFactory.getLogger(this.getClass) + + private val typePrinterOptions = new DefaultPrinterConfiguration() + .removeOption(new DefaultConfigurationOption(ConfigOption.PRINT_COMMENTS)) + .removeOption(new DefaultConfigurationOption(ConfigOption.PRINT_JAVADOC)) + + private def typeToString(typ: Type): String = { + Util.stripGenericTypes(typ.toString(typePrinterOptions)) + } + + val unspecifiedClassType: String = { + val writer = SignatureWriter() + + writer.visitClassType(unspecifiedType) + writer.visitEnd() + + writer.toString + } + + def defaultConstructorSignature(parameters: List[Parameter]): String = { + val writer = SignatureWriter() + + parameters.foreach { param => + writer.visitParameterType() + addType(writer, param.getType) + } + + writer.visitReturnType() + addType(writer, new VoidType()) + + writer.toString + } + + def recordParameterAccessorBinarySignature(parameter: Parameter): String = { + val writer = SignatureWriter() + + writer.visitReturnType() + addType(writer, parameter.getType) + + writer.toString + } + + def enumEntryBinarySignature(enumEntry: EnumConstantDeclaration): String = { + val writer = SignatureWriter() + + enumEntry.getParentNode.toScala collect { case enumDeclaration: EnumDeclaration => + writer.visitClassType(enumDeclaration.getNameAsString) + writer.visitEnd() + } + + writer.toString + } + + def variableBinarySignature(typ: String): String = { + val writer = SignatureWriter() + + BaseTypeMap.get(typ) match { + case Some(baseType) => + writer.visitBaseType(baseType) + + case None => + writer.visitClassType(typ) + writer.visitEnd() + } + + writer.toString + } + + def typeDeclBinarySignature(typeDeclaration: TypeDeclaration[?]): String = { + typeDeclaration match { + case decl: AnnotationDeclaration => annotationDecBinarySignature(decl) + case decl: RecordDeclaration => recordDeclBinarySignature(decl) + case decl: ClassOrInterfaceDeclaration => classDeclBinarySignature(decl) + case decl: EnumDeclaration => enumDeclBinarySignature(decl) + case decl => + throw new IllegalArgumentException( + s"Attempting to get binary signature for unhandled type declaration $typeDeclaration" + ) + } + } + + def annotationDecBinarySignature(annotationDecl: AnnotationDeclaration): String = { + val writer = SignatureWriter() + + writer.visitClassType(javaObjectName) + writer.visitEnd() + + writer.toString + } + + def enumDeclBinarySignature(enumDecl: EnumDeclaration): String = { + val writer = SignatureWriter() + + writer.visitSuperclass() + writer.visitClassType(javaEnumName) + writer.visitTypeArgument('=') + writer.visitClassType(enumDecl.getNameAsString) + writer.visitEnd() + + enumDecl.getImplementedTypes.asScala.foreach(addType(writer, _)) + writer.visitEnd() + + writer.toString + } + + def patternVariableBinarySignature(typePatternExpr: TypePatternExpr): String = { + val writer = SignatureWriter() + addType(writer, typePatternExpr.getType) + writer.toString + } + + def classDeclBinarySignature( + classDecl: ClassOrInterfaceDeclaration, + classNameOverride: Option[String] = None + ): String = { + val writer = SignatureWriter() + classDecl.getTypeParameters.asScala.foreach(addTypeParam(writer, _)) + + writer.visitSuperclass() + if (classDecl.isInterface) { + writer.visitClassType(javaObjectName) + writer.visitEnd() + classDecl.getExtendedTypes.asScala.foreach(addType(writer, _)) + } else { + if (classDecl.getExtendedTypes.isEmpty) { + writer.visitClassType(javaObjectName) + writer.visitEnd() + } else { + classDecl.getExtendedTypes.asScala.foreach(addType(writer, _)) + } + classDecl.getImplementedTypes.asScala.foreach(addType(writer, _)) + } + + writer.toString + } + + def recordDeclBinarySignature(recordDecl: RecordDeclaration): String = { + val writer = SignatureWriter() + recordDecl.getTypeParameters.asScala.foreach(addTypeParam(writer, _)) + + writer.visitSuperclass() + writer.visitClassType(javaRecordName) + writer.visitEnd() + + recordDecl.getImplementedTypes.asScala.foreach(addType(writer, _)) + + writer.toString + } + + def variableBinarySignature(variableType: Type): String = { + val writer = SignatureWriter() + addType(writer, variableType) + writer.toString + } + + def methodBinarySignature(callableDecl: CallableDeclaration[?]): String = { + val writer = SignatureWriter() + callableDecl.getTypeParameters.asScala.foreach(addTypeParam(writer, _)) + + callableDecl.getParameters.asScala.foreach { param => + writer.visitParameterType() + addType(writer, param.getType) + } + + writer.visitReturnType() + callableDecl match { + case methodDeclaration: MethodDeclaration => addType(writer, methodDeclaration.getType) + case constructorDeclaration: ConstructorDeclaration => + BaseTypeMap.get(TypeConstants.Void).foreach(writer.visitBaseType(_)) + } + + callableDecl.getThrownExceptions.asScala.foreach { exception => + writer.visitExceptionType() + addType(writer, exception) + } + + writer.toString + } + + def lambdaMethodBinarySignature(expr: LambdaExpr): String = { + val writer = SignatureWriter() + + expr.getParameters.asScala.foreach { param => + writer.visitParameterType() + addType(writer, param.getType) + } + + writer.visitReturnType() + writer.visitClassType(unspecifiedType) + writer.visitEnd() + + writer.toString + } + + private def addTypeParam(writer: SignatureWriter, typeParam: TypeParameter): Unit = { + writer.visitFormalTypeParameter(typeParam.getNameAsString()) + writer.visitClassBound() + val typeBoundIt = typeParam.getTypeBound.asScala.iterator + if (typeBoundIt.isEmpty) { + writer.visitClassType(javaObjectName) + writer.visitEnd() + } else { + addType(writer, typeBoundIt.next) + } + typeBoundIt.foreach { typeBound => + writer.visitInterfaceBound() + addType(writer, typeBound) + } + } + + private def isTypeVariable(name: String): Boolean = { + scope.lookupScopeType(name, includeWildcards = false).exists(_.isInstanceOf[ScopeTypeParam]) + } + + private def addType(writer: SignatureWriter, typ: Type): Unit = { + val name = typeToString(typ) + typ match { + case _ if isTypeVariable(name) => + writer.visitTypeVariable(typeToString(typ)) + case classOrInterface: ClassOrInterfaceType => + val internalClassName = scope + .lookupScopeType(name, includeWildcards = false) + .map(_.name) + .getOrElse(name) + writer.visitClassType(internalClassName) + val typeArgs = classOrInterface.getTypeArguments.toScala.map(_.asScala).getOrElse(Nil) + typeArgs.foreach { + case wildcardType: WildcardType => + if (wildcardType.getExtendedType.isPresent) { + writer.visitTypeArgument('+') + addType(writer, wildcardType.getExtendedType.get()) + } else if (wildcardType.getSuperType.isPresent) { + writer.visitTypeArgument('-') + addType(writer, wildcardType.getSuperType.get()) + } else writer.visitTypeArgument('*') + case typeArg => + writer.visitTypeArgument('=') + addType(writer, typeArg) + } + writer.visitEnd() + case arrayType: ArrayType => + writer.visitArrayType() + addType(writer, arrayType.getElementType) + case typeParam: TypeParameter => + writer.visitTypeVariable(typeParam.getNameAsString) + case primitiveType: PrimitiveType => + writer.visitBaseType(primitiveType.getType.toDescriptor.charAt(0)) + case varType: VarType => + writer.visitClassType(unspecifiedType) + writer.visitEnd() + case _: VoidType => + writer.visitBaseType('V') + case _: UnknownType => + writer.visitClassType(unspecifiedType) + writer.visitEnd() + } + } + +} diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForCallExpressionsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForCallExpressionsCreator.scala index 88687b18fa29..9bd15461ff37 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForCallExpressionsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForCallExpressionsCreator.scala @@ -52,16 +52,14 @@ trait AstForCallExpressionsCreator { this: AstCreator => private val logger = LoggerFactory.getLogger(this.getClass) - private var tempConstCount = 0 - private[expressions] def astForMethodCall(call: MethodCallExpr, expectedReturnType: ExpectedType): Ast = { val maybeResolvedCall = tryWithSafeStackOverflow(call.resolve()) val argumentAsts = argAstsForCall(call, maybeResolvedCall, call.getArguments) val expressionTypeFullName = - expressionReturnTypeFullName(call).orElse(expectedReturnType.fullName).map(typeInfoCalc.registerType) + expressionReturnTypeFullName(call).orElse(getTypeFullName(expectedReturnType)).map(typeInfoCalc.registerType) - val argumentTypes = argumentTypesForMethodLike(maybeResolvedCall) + val argumentTypes = argumentTypesForMethodLike(maybeResolvedCall.toOption) val returnType = maybeResolvedCall .map { resolvedCall => typeInfoCalc.fullName(resolvedCall.getReturnType, ResolvedTypeParametersMap.empty()) @@ -105,17 +103,17 @@ trait AstForCallExpressionsCreator { this: AstCreator => .dispatchType(dispatchType) .lineNumber(line(call)) .columnNumber(column(call)) - .typeFullName(expressionTypeFullName.getOrElse(TypeConstants.Any)) + .typeFullName(expressionTypeFullName.getOrElse(defaultTypeFallback())) callAst(callRoot, argumentAsts, scopeAsts.headOption) } private def astForImplicitCallReceiver(declaringType: Option[String], call: MethodCallExpr): Ast = { - val typeFullName = scope.lookupVariable(NameConstants.This).typeFullName.getOrElse(TypeConstants.Any) + val typeFullName = scope.lookupVariable(NameConstants.This).typeFullName.getOrElse(defaultTypeFallback()) val thisIdentifier = identifierNode(call, NameConstants.This, NameConstants.This, typeFullName) scope.lookupVariable(NameConstants.This) match { - case SimpleVariable(ScopeParameter(thisParam)) => diffGraph.addEdge(thisIdentifier, thisParam, EdgeTypes.REF) + case SimpleVariable(ScopeParameter(thisParam, _)) => diffGraph.addEdge(thisIdentifier, thisParam, EdgeTypes.REF) case _ => // Do nothing. This shouldn't happen for valid code, but could occur in cases where methods could not be resolved } val thisAst = Ast(thisIdentifier) @@ -143,17 +141,18 @@ trait AstForCallExpressionsCreator { this: AstCreator => } private[expressions] def blockAstForObjectCreationExpr(expr: ObjectCreationExpr, expectedType: ExpectedType): Ast = { - val tmpName = "$obj" ++ tempConstCount.toString - tempConstCount += 1 + val tmpName = tempNameProvider.next // Use an untyped identifier for receiver here, create the alloc and init ASTs, // then use the types of those to fix the local type. - val assignTarget = identifierNode(expr, tmpName, tmpName, TypeConstants.Any) + val assignTarget = identifierNode(expr, tmpName, tmpName, defaultTypeFallback()) val allocAndInitAst = inlinedAstsForObjectCreationExpr(expr, Ast(assignTarget.copy), expectedType, resetAssignmentTargetType = true) - assignTarget.typeFullName(allocAndInitAst.allocAst.rootType.getOrElse(TypeConstants.Any)) - val tmpLocal = localNode(expr, tmpName, tmpName, assignTarget.typeFullName) + assignTarget.typeFullName(allocAndInitAst.allocAst.rootType.getOrElse(defaultTypeFallback())) + val genericSignature = binarySignatureCalculator.variableBinarySignature(expr.getType) + val tmpLocal = + localNode(expr, tmpName, tmpName, assignTarget.typeFullName, genericSignature = Option(genericSignature)) val allocAssignCode = s"$tmpName = ${allocAndInitAst.allocAst.rootCodeOrEmpty}" val allocAssignCall = @@ -204,9 +203,12 @@ trait AstForCallExpressionsCreator { this: AstCreator => val baseTypeFullName = baseTypeFromScope .map(_.typeFullName) - .orElse( - tryWithSafeStackOverflow(typeInfoCalc.fullName(expr.getType)).toOption.flatten.orElse(expectedType.fullName) - ) + .orElse(tryWithSafeStackOverflow(expr.getType).toOption.map { typ => + typeInfoCalc + .fullName(typ) + .orElse(getTypeFullName(expectedType)) + .getOrElse(defaultTypeFallback(typ)) + }) val typeFullName = if (anonymousClassBody.isEmpty) baseTypeFullName.map(typeFullName => s"$typeFullName$nameSuffix") @@ -224,7 +226,7 @@ trait AstForCallExpressionsCreator { this: AstCreator => case Some(TypeConstants.Any) => typeFullName case Some(PropertyDefaults.TypeFullName) => typeFullName case Some(typ) => Option(typ) - case None => TypeConstants.Any + case None => defaultTypeFallback() } anonymousClassBody.foreach { bodyStmts => @@ -232,18 +234,18 @@ trait AstForCallExpressionsCreator { this: AstCreator => scope.addLocalDecl(anonymousClassDecl) } - val argumentTypes = argumentTypesForMethodLike(maybeResolvedExpr) + val argumentTypes = argumentTypesForMethodLike(maybeResolvedExpr.toOption) val allocNode = newOperatorCallNode( Operators.alloc, expr.toString, - typeFullName.orElse(Some(TypeConstants.Any)), + typeFullName.orElse(Some(defaultTypeFallback())), line(expr), column(expr) ) val initCall = initNode( - typeFullName.orElse(Some(TypeConstants.Any)), + typeFullName.orElse(Some(defaultTypeFallback())), argumentTypes, argumentAsts.size, expr.toString, @@ -294,7 +296,7 @@ trait AstForCallExpressionsCreator { this: AstCreator => tryResolvedDecl match { case Success(_) if hasVariadicParameter => - val expectedVariadicTypeFullName = getExpectedParamType(tryResolvedDecl, paramCount - 1).fullName + val expectedVariadicTypeFullName = getTypeFullName(getExpectedParamType(tryResolvedDecl, paramCount - 1)) val (regularArgs, varargs) = argsAsts.splitAt(paramCount - 1) val arrayInitializer = newOperatorCallNode( Operators.arrayInitializer, diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForExpressionsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForExpressionsCreator.scala index 4bc53c567cae..5ee657cc3eaa 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForExpressionsCreator.scala @@ -33,6 +33,7 @@ trait AstForExpressionsCreator with AstForLambdasCreator with AstForCallExpressionsCreator with AstForNameExpressionsCreator + with AstForPatternExpressionsCreator with AstForVarDeclAndAssignsCreator { this: AstCreator => def astsForExpression(expression: Expression, expectedType: ExpectedType): Seq[Ast] = { // TODO: Implement missing handlers diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForLambdasCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForLambdasCreator.scala index f38c1d958779..98df127f3e47 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForLambdasCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForLambdasCreator.scala @@ -17,13 +17,12 @@ import io.joern.javasrc2cpg.util.BindingTable.createBindingTable import io.joern.javasrc2cpg.util.Util.{composeMethodFullName, composeMethodLikeSignature, composeUnresolvedSignature} import io.joern.javasrc2cpg.util.{BindingTable, BindingTableAdapterForLambdas, LambdaBindingInfo, NameConstants} import io.joern.x2cpg.utils.AstPropertiesUtil.* -import io.joern.x2cpg.utils.NodeBuilders +import io.joern.x2cpg.utils.{IntervalKeyPool, NodeBuilders} import io.joern.x2cpg.utils.NodeBuilders.* import io.joern.x2cpg.{Ast, Defines} import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.nodes.MethodParameterIn.PropertyDefaults as ParameterDefaults import io.shiftleft.codepropertygraph.generated.{EdgeTypes, EvaluationStrategies, ModifierTypes} -import io.shiftleft.passes.IntervalKeyPool import org.slf4j.LoggerFactory import scala.jdk.CollectionConverters.* @@ -55,6 +54,7 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => variablesInScope: Seq[ScopeVariable], expectedLambdaType: ExpectedType ): (NewMethod, LambdaBody) = { + val implementedMethod = implementedInfo.implementedMethod val implementedInterface = implementedInfo.implementedInterface @@ -62,17 +62,21 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => // symbol solver returns the erased types when resolving the lambda itself. val expectedTypeParamTypes = genericParamTypeMapForLambda(expectedLambdaType) val parametersWithoutThis = buildParamListForLambda(expr, implementedMethod, expectedTypeParamTypes) + val returnType = getLambdaReturnType(implementedInterface, implementedMethod, expectedTypeParamTypes) + + val lambdaMethodNode = createLambdaMethodNode(expr, lambdaMethodName, parametersWithoutThis, returnType) - val returnType = getLambdaReturnType(implementedInterface, implementedMethod, expectedTypeParamTypes) + // TODO: lambda method scope can be static if no non-static captures are used + scope.pushMethodScope(lambdaMethodNode, expectedLambdaType, isStatic = false) - val lambdaBody = astForLambdaBody(lambdaMethodName, expr.getBody, variablesInScope, returnType) + val lambdaBody = astForLambdaBody(expr, lambdaMethodName, expr.getBody, variablesInScope, returnType) val thisParam = lambdaBody.nodes .collect { case identifier: NewIdentifier => identifier } .find { identifier => identifier.name == NameConstants.This || identifier.name == NameConstants.Super } .map { _ => val typeFullName = scope.enclosingTypeDecl.fullName - Ast(thisNodeForMethod(typeFullName, line(expr))) + Ast(thisNodeForMethod(typeFullName, line(expr), column(expr))) } .toList @@ -89,9 +93,7 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => .collect { case identifier: NewIdentifier => identifier } .filter { identifier => lambdaParameterNamesToNodes.contains(identifier.name) } - val lambdaMethodNode = createLambdaMethodNode(expr, lambdaMethodName, parametersWithoutThis, returnType) - - val returnNode = newMethodReturnNode(returnType.getOrElse(TypeConstants.Any), None, line(expr), column(expr)) + val returnNode = newMethodReturnNode(returnType.getOrElse(defaultTypeFallback()), None, line(expr), column(expr)) val virtualModifier = Some(newModifierNode(ModifierTypes.VIRTUAL)) val staticModifier = Option.when(thisParam.isEmpty)(newModifierNode(ModifierTypes.STATIC)) val privateModifier = Some(newModifierNode(ModifierTypes.PRIVATE)) @@ -110,6 +112,7 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => ast.withRefEdge(identifier, lambdaParameterNamesToNodes(identifier.name)) ) + scope.popMethodScope() scope.addLambdaMethod(lambdaMethodAst) lambdaMethodNode -> lambdaBody @@ -120,8 +123,8 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => val containsEmptyType = maybeParameterTypes.exists(_.contains(ParameterDefaults.TypeFullName)) (returnType, maybeParameterTypes) match { - case (Some(returnTpe), Some(parameterTpes)) if !containsEmptyType => - composeMethodLikeSignature(returnTpe, parameterTpes) + case (Some(returnType), Some(parameterTypes)) if !containsEmptyType => + composeMethodLikeSignature(returnType, parameterTypes) case _ => composeUnresolvedSignature(parameters.size) } @@ -137,7 +140,16 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => val signature = lambdaMethodSignature(returnType, parameters) val lambdaFullName = composeMethodFullName(enclosingTypeName, lambdaName, signature) - methodNode(lambdaExpr, lambdaName, "", lambdaFullName, Some(signature), filename) + val genericSignature = binarySignatureCalculator.lambdaMethodBinarySignature(lambdaExpr) + methodNode( + lambdaExpr, + lambdaName, + "", + lambdaFullName, + Some(signature), + filename, + genericSignature = Option(genericSignature) + ) } private def createAndPushLambdaTypeDecl( @@ -156,6 +168,7 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => .fullName(lambdaMethodNode.fullName) .name(lambdaMethodNode.name) .inheritsFromTypeFullName(inheritsFromTypeFullName) + .genericSignature(binarySignatureCalculator.unspecifiedClassType) scope.addLocalDecl(Ast(lambdaTypeDeclNode)) lambdaTypeDeclNode @@ -208,8 +221,6 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => } def astForLambdaExpr(expr: LambdaExpr, expectedType: ExpectedType): Ast = { - // TODO: lambda method scope can be static if no non-static captures are used - scope.pushMethodScope(NewMethod(), expectedType, isStatic = false) val lambdaMethodName = nextClosureName() @@ -238,7 +249,6 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => val lambdaTypeDeclNode = createAndPushLambdaTypeDecl(lambdaMethodNode, implementedInfo) BindingTable.createBindingNodes(diffGraph, lambdaTypeDeclNode, bindingTable) - scope.popMethodScope() Ast(methodRef) } @@ -257,6 +267,7 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => } private def defineCapturedVariables( + lambdaNode: LambdaExpr, lambdaMethodName: String, capturedVariables: Seq[ScopeVariable] ): Seq[(ClosureBindingEntry, NewLocal)] = { @@ -267,7 +278,14 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => val closureBindingNode = newClosureBindingNode(closureBindingId, name, EvaluationStrategies.BY_SHARING) val scopeVariable = variables.head - val capturedLocal = newLocalNode(scopeVariable.name, scopeVariable.typeFullName, Option(closureBindingId)) + val capturedLocal = localNode( + lambdaNode, + scopeVariable.name, + scopeVariable.name, + scopeVariable.typeFullName, + Option(closureBindingId), + Option(scopeVariable.genericSignature) + ) scope.enclosingBlock.foreach(_.addLocal(capturedLocal)) ClosureBindingEntry(scopeVariable, closureBindingNode) -> capturedLocal @@ -276,6 +294,7 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => } private def astForLambdaBody( + lambdaExpr: LambdaExpr, lambdaMethodName: String, body: Statement, variablesInScope: Seq[ScopeVariable], @@ -300,22 +319,16 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => stmts.flatMap(_.nodes).collect { case i: NewIdentifier if outerScopeVariableNames.contains(i.name) => outerScopeVariableNames(i.name) } - val bindingsToLocals = defineCapturedVariables(lambdaMethodName, capturedVariables) + val bindingsToLocals = defineCapturedVariables(lambdaExpr, lambdaMethodName, capturedVariables) val capturedLocalAsts = bindingsToLocals.map(_._2).map(Ast(_)) val closureBindingEntries = bindingsToLocals.map(_._1) + val temporaryLocalAsts = scope.enclosingMethod.map(_.getTemporaryLocals).getOrElse(Nil).map(Ast(_)) - body match { - case block: BlockStmt => - val blockAst = Ast(blockNode(block)) - .withChildren(capturedLocalAsts) - .withChildren(stmts) - LambdaBody(blockAst, closureBindingEntries) - case stmt => - val blockAst = Ast(blockNode(stmt)) - .withChildren(capturedLocalAsts) - .withChildren(stmts) - LambdaBody(blockAst, closureBindingEntries) - } + val blockAst = Ast(blockNode(body)) + .withChildren(temporaryLocalAsts) + .withChildren(capturedLocalAsts) + .withChildren(stmts) + LambdaBody(blockAst, closureBindingEntries) } private def genericParamTypeMapForLambda(expectedType: ExpectedType): ResolvedTypeParametersMap = { @@ -364,7 +377,7 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => .zipWithIndex .map { case ((param, maybeType), idx) => val name = param.getNameAsString - val typeFullName = maybeType.getOrElse(TypeConstants.Any) + val typeFullName = maybeType.getOrElse(defaultTypeFallback()) val code = s"$typeFullName $name" val evalStrat = if (tryWithSafeStackOverflow(param.getType).toOption.exists(_.isPrimitiveType)) EvaluationStrategies.BY_VALUE @@ -383,7 +396,8 @@ private[expressions] trait AstForLambdasCreator { this: AstCreator => } parameterNodes.foreach { paramNode => - scope.enclosingMethod.get.addParameter(paramNode) + scope.enclosingMethod.get + .addParameter(paramNode, binarySignatureCalculator.unspecifiedClassType) } parameterNodes.map(Ast(_)) diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForNameExpressionsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForNameExpressionsCreator.scala index 8b8ad10c6a34..bd22bb96a917 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForNameExpressionsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForNameExpressionsCreator.scala @@ -1,12 +1,20 @@ package io.joern.javasrc2cpg.astcreation.expressions +import com.github.javaparser.ast.Node import com.github.javaparser.ast.expr.NameExpr import com.github.javaparser.resolution.declarations.ResolvedFieldDeclaration import io.joern.javasrc2cpg.astcreation.{AstCreator, ExpectedType} +import io.joern.javasrc2cpg.scope.JavaScopeElement.TypeDeclScope import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants import io.joern.javasrc2cpg.util.NameConstants import io.joern.x2cpg.{Ast, Defines} -import io.shiftleft.codepropertygraph.generated.nodes.{NewLocal, NewMethodParameterIn} +import io.shiftleft.codepropertygraph.generated.nodes.{ + NewLocal, + NewMethodParameterIn, + NewTypeDecl, + NewTypeRef, + NewUnknown +} import scala.util.Success import io.joern.javasrc2cpg.scope.Scope.{ @@ -14,15 +22,15 @@ import io.joern.javasrc2cpg.scope.Scope.{ NotInScope, ScopeMember, ScopeParameter, + ScopePatternVariable, ScopeVariable, SimpleVariable } -import io.shiftleft.codepropertygraph.generated.nodes.NewTypeDecl import org.slf4j.LoggerFactory -import io.shiftleft.codepropertygraph.generated.nodes.NewUnknown import io.joern.x2cpg.utils.AstPropertiesUtil.* import io.shiftleft.codepropertygraph.generated.Operators -import io.joern.x2cpg.utils.NodeBuilders.newOperatorCallNode +import io.joern.x2cpg.utils.NodeBuilders.{newIdentifierNode, newOperatorCallNode} +import io.joern.javasrc2cpg.scope.PatternVariableInfo trait AstForNameExpressionsCreator { this: AstCreator => @@ -31,7 +39,7 @@ trait AstForNameExpressionsCreator { this: AstCreator => private[expressions] def astForNameExpr(nameExpr: NameExpr, expectedType: ExpectedType): Ast = { val name = nameExpr.getName.toString val typeFullName = expressionReturnTypeFullName(nameExpr) - .orElse(expectedType.fullName) + .orElse(getTypeFullName(expectedType)) .map(typeInfoCalc.registerType) scope.lookupVariable(name) match { @@ -39,19 +47,27 @@ trait AstForNameExpressionsCreator { this: AstCreator => astForStaticImportOrUnknown(nameExpr, name, typeFullName) case SimpleVariable(variable: ScopeMember) => - val targetName = - Option.when(variable.isStatic)(scope.enclosingTypeDecl.fullName).flatten.getOrElse(NameConstants.This) - fieldAccessAst( - targetName, - scope.enclosingTypeDecl.fullName, + createImplicitBaseFieldAccess( + variable.isStatic, + scope.enclosingTypeDecl.name.get, + scope.enclosingTypeDecl.fullName.get, + nameExpr, variable.name, - Some(variable.typeFullName), - line(nameExpr), - column(nameExpr) + variable.typeFullName ) + case SimpleVariable(ScopePatternVariable(localNode, typePatternExpr)) => + scope.enclosingMethod.flatMap(_.getPatternVariableInfo(typePatternExpr)) match { + case Some(PatternVariableInfo(typePatternExpr, _, initializerAst, _, false, _)) => + scope.enclosingMethod.foreach(_.registerPatternVariableInitializerToBeAddedToGraph(typePatternExpr)) + initializerAst + case _ => + val identifier = identifierNode(nameExpr, localNode.name, localNode.name, localNode.typeFullName) + Ast(identifier).withRefEdge(identifier, localNode) + } + case SimpleVariable(variable) => - val identifier = identifierNode(nameExpr, name, name, typeFullName.getOrElse(TypeConstants.Any)) + val identifier = identifierNode(nameExpr, name, name, typeFullName.getOrElse(defaultTypeFallback())) val captured = variable.node match { case param: NewMethodParameterIn => Some(param) case local: NewLocal => Some(local) @@ -65,31 +81,51 @@ trait AstForNameExpressionsCreator { this: AstCreator => } } + private[expressions] def createImplicitBaseFieldAccess( + isStatic: Boolean, + baseTypeDeclName: String, + baseTypeDeclFullName: String, + node: Node, + fieldName: String, + fieldTypeFullName: String + ): Ast = { + val base = + if (isStatic) { + NewTypeRef() + .code(baseTypeDeclName) + .typeFullName(baseTypeDeclFullName) + .lineNumber(line(node)) + .columnNumber(column(node)) + } else { + newIdentifierNode(NameConstants.This, baseTypeDeclFullName) + } + fieldAccessAst( + Ast(base), + s"${base.code}.$fieldName", + line(node), + column(node), + fieldName, + fieldTypeFullName, + line(node), + column(node) + ) + } + private def astForStaticImportOrUnknown(nameExpr: NameExpr, name: String, typeFullName: Option[String]): Ast = { tryWithSafeStackOverflow(nameExpr.resolve()) match { - case Success(value) if value.isField => - val identifierName = if (value.asField.isStatic) { - // TODO: v is wrong. Statically imported expressions can also be represented by just the name. - // A static field represented by a NameExpr must belong to the class in which it's used. Static fields - // from other classes are represented by a FieldAccessExpr instead. - scope.enclosingTypeDecl.map(_.typeDecl.name).getOrElse(s"${Defines.UnresolvedNamespace}.$name") - } else { - NameConstants.This - } - - val identifierTypeFullName = - value match { - case fieldDecl: ResolvedFieldDeclaration => - // TODO It is not quite correct to use the declaring classes type. - // Instead we should take the using classes type which is either the same or a - // sub class of the declaring class. - typeInfoCalc.fullName(fieldDecl.declaringType()) - } - - fieldAccessAst(identifierName, identifierTypeFullName, name, typeFullName, line(nameExpr), column(nameExpr)) + case Success(value: ResolvedFieldDeclaration) => + // TODO using the enclosingTypeDecl is wrong if the field was imported via a static import. + createImplicitBaseFieldAccess( + value.asField().isStatic, + typeInfoCalc.name(value.declaringType()).getOrElse(defaultTypeFallback()), + typeInfoCalc.fullName(value.declaringType()).getOrElse(defaultTypeFallback()), + nameExpr, + name, + typeFullName.getOrElse(defaultTypeFallback()) + ) case _ => - Ast(identifierNode(nameExpr, name, name, typeFullName.getOrElse(TypeConstants.Any))) + Ast(identifierNode(nameExpr, name, name, typeFullName.getOrElse(defaultTypeFallback()))) } } diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForPatternExpressionsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForPatternExpressionsCreator.scala new file mode 100644 index 000000000000..105ee7a2cc85 --- /dev/null +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForPatternExpressionsCreator.scala @@ -0,0 +1,387 @@ +package io.joern.javasrc2cpg.astcreation.expressions + +import com.github.javaparser.ast.Node +import com.github.javaparser.ast.expr.{PatternExpr, RecordPatternExpr, TypePatternExpr} +import io.joern.javasrc2cpg.astcreation.AstCreator +import io.joern.javasrc2cpg.jartypereader.model.Model.TypeConstants +import io.joern.javasrc2cpg.scope.Scope.NewVariableNode +import io.joern.javasrc2cpg.util.Util +import io.joern.x2cpg.{Ast, Defines} +import io.joern.x2cpg.utils.AstPropertiesUtil.* +import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier} +import io.joern.x2cpg.utils.NodeBuilders.* +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} +import org.slf4j.LoggerFactory + +import java.util +import scala.collection.mutable +import scala.jdk.CollectionConverters.* +import scala.jdk.OptionConverters.RichOptional + +class PatternInitAndRefAsts(private val initAst: Ast, private val refAst: Ast) { + private var getCount: Int = -1 + def get: Ast = { + getCount += 1 + getCount match { + case 0 => initAst + case 1 => refAst + case _ => refAst.subTreeCopy(refAst.root.get.asInstanceOf[AstNodeNew]) + } + } + + def rootType: Option[String] = initAst.rootType + + def asTuple: (Ast, Ast) = (initAst, refAst) +} + +object PatternInitAndRefAsts { + def apply(initAst: Ast, refAst: Ast): PatternInitAndRefAsts = new PatternInitAndRefAsts(initAst, refAst) + + def apply(initAst: Ast): PatternInitAndRefAsts = + new PatternInitAndRefAsts(initAst, initAst.subTreeCopy(initAst.root.get.asInstanceOf[AstNodeNew])) +} + +trait AstForPatternExpressionsCreator { this: AstCreator => + + private val logger = LoggerFactory.getLogger(this.getClass) + + trait PatternInitTreeNode(val patternExpr: PatternExpr) { + def getAst: Ast + + def typeFullName: Option[String] + } + + class PatternInitRoot(patternExpr: PatternExpr, ast: PatternInitAndRefAsts) extends PatternInitTreeNode(patternExpr) { + override def getAst: Ast = ast.get + + override def typeFullName: Option[String] = ast.rootType + } + + class PatternInitNode( + parentNode: PatternInitTreeNode, + patternExpr: PatternExpr, + fieldName: String, + fieldTypeFullName: Option[String], + requiresTemporaryVariable: Boolean + ) extends PatternInitTreeNode(patternExpr) { + private var cachedResult: Option[PatternInitAndRefAsts] = None + + override def typeFullName: Option[String] = fieldTypeFullName + + override def getAst: Ast = { + cachedResult.map(_.get).getOrElse { + val parentAst = parentNode.getAst + val patternTypeFullName = tryWithSafeStackOverflow(patternExpr.getType).toOption + .map { typ => + scope + .lookupScopeType(typ.asString()) + .map(_.typeFullName) + .orElse(typeInfoCalc.fullName(typ)) + .getOrElse(defaultTypeFallback(typ)) + } + .getOrElse(defaultTypeFallback()) + + val parentPatternType = getPatternTypeFullName(parentNode.patternExpr) + val lhsAst = castAstIfNecessary(parentNode.patternExpr, parentPatternType, parentAst) + + val signature = composeSignature(fieldTypeFullName, Option(Nil), 0) + val typeDeclFullName = + if (isResolvedTypeFullName(parentPatternType)) + parentPatternType + else + s"${Defines.UnresolvedNamespace}.${code(parentNode.patternExpr.getType)}" + val methodFullName = Util.composeMethodFullName(typeDeclFullName, fieldName, signature) + val methodCodePrefix = lhsAst.root match { + case Some(call: NewCall) if call.name.startsWith(" s"(${call.code})" + case Some(root: AstNodeNew) => root.code + case _ => "" + + } + val methodCode = s"$methodCodePrefix.$fieldName()" + + val fieldAccessorCall = callNode( + patternExpr, + methodCode, + fieldName, + methodFullName, + DispatchTypes.DYNAMIC_DISPATCH, + Option(signature), + fieldTypeFullName.orElse(Option(defaultTypeFallback())) + ) + + val fieldAccessorAst = callAst(fieldAccessorCall, Nil, Option(lhsAst)) + + val patternInitWithRef = if (requiresTemporaryVariable) { + val patternInitWithRef = initAndRefAstsForPatternInitializer(patternExpr, fieldAccessorAst) + patternInitWithRef + } else { + PatternInitAndRefAsts(fieldAccessorAst) + } + + cachedResult = Option(patternInitWithRef) + + patternInitWithRef.get + } + } + } + + /** In the lowering for instanceof expressions with patterns like `X instanceof Foo f`, the first argument to + * `instanceof` (in this case `X`) appears in the CPG at least 2 times: + * - once for the `X instanceof Foo` check + * - once for the `Foo f = (Foo) X` assignment. + * + * If X is an identifier or field access, then this is fine. If X is a call which could have side-effects, however, + * then this representation could lead to incorrect behaviour. + * + * This method solves this problem by taking the CPG lowering for X as input and returning a PatternInitAndRefAsts + * object. The first time `get` is called on one of these, the init AST is return. Every future get call returns the + * reference AST, ensuring that the variable is initialized exactly once + */ + private[astcreation] def initAndRefAstsForPatternInitializer( + rootNode: Node, + patternInitAst: Ast + ): PatternInitAndRefAsts = { + patternInitAst.root match { + case Some(identifier: NewIdentifier) => + PatternInitAndRefAsts(patternInitAst, patternInitAst.subTreeCopy(identifier)) + + case Some(fieldAccess: NewCall) if fieldAccess.name == Operators.fieldAccess => + PatternInitAndRefAsts( + patternInitAst, + patternInitAst.subTreeCopy(patternInitAst.root.get.asInstanceOf[AstNodeNew]) + ) + + case _ => + val tmpName = tempNameProvider.next + val tmpType = patternInitAst.rootType.getOrElse(TypeConstants.Object) + val tmpLocal = localNode( + rootNode, + tmpName, + tmpName, + tmpType, + genericSignature = Option(binarySignatureCalculator.unspecifiedClassType) + ) + val tmpIdentifier = identifierNode(rootNode, tmpName, tmpName, tmpType) + + val tmpAssignmentNode = + newOperatorCallNode( + Operators.assignment, + s"$tmpName = ${patternInitAst.rootCodeOrEmpty}", + Option(tmpType), + line(rootNode), + column(rootNode) + ) + + // Don't need to add the local to the block scope since the only identifiers referencing it are created here + // (so a lookup for the local will never be done) + scope.enclosingMethod.foreach(_.addTemporaryLocal(tmpLocal)) + + val initAst = + callAst(tmpAssignmentNode, Ast(tmpIdentifier) :: patternInitAst :: Nil).withRefEdge(tmpIdentifier, tmpLocal) + + val tmpIdentifierCopy = tmpIdentifier.copy + val referenceAst = Ast(tmpIdentifierCopy).withRefEdge(tmpIdentifierCopy, tmpLocal) + + PatternInitAndRefAsts(initAst, referenceAst) + } + } + + private def castAstIfNecessary(patternExpr: PatternExpr, patternType: String, initializerAst: Ast): Ast = { + val initializerType = initializerAst.rootType + if (isResolvedTypeFullName(patternType) && initializerType.contains(patternType)) { + initializerAst + } else { + val castType = typeRefNode(patternExpr, code(patternExpr.getType), patternType) + val castNode = + newOperatorCallNode( + Operators.cast, + s"(${castType.code}) ${initializerAst.rootCodeOrEmpty}", + Option(patternType), + line(patternExpr), + column(patternExpr) + ) + callAst(castNode, Ast(castType) :: initializerAst :: Nil) + } + } + + private def createAndPushAssignmentForTypePattern(patternNode: PatternInitTreeNode): Unit = { + patternNode.patternExpr match { + case recordPatternExpr: RecordPatternExpr => + logger.warn(s"Attempting to create assignment for record pattern expr ${code(recordPatternExpr)}") + + case typePatternExpr: TypePatternExpr => + val variableName = typePatternExpr.getNameAsString + val variableType = { + tryWithSafeStackOverflow(typePatternExpr.getType).toOption + .map(typ => + scope + .lookupScopeType(typ.asString()) + .map(_.typeFullName) + .orElse(typeInfoCalc.fullName(typ)) + .getOrElse(defaultTypeFallback(typ)) + ) + .getOrElse(defaultTypeFallback()) + } + val variableTypeCode = tryWithSafeStackOverflow(code(typePatternExpr.getType)).getOrElse(variableType) + val genericSignature = binarySignatureCalculator.variableBinarySignature(typePatternExpr.getType) + val patternLocal = localNode( + typePatternExpr, + variableName, + code(typePatternExpr), + variableType, + genericSignature = Option(genericSignature) + ) + val patternIdentifier = identifierNode(typePatternExpr, variableName, variableName, variableType) + + val initializerAst = castAstIfNecessary(typePatternExpr, variableType, patternNode.getAst) + + val initializerAssignmentCall = newOperatorCallNode( + Operators.assignment, + s"$variableName = ${initializerAst.rootCodeOrEmpty}", + Option(variableType), + line(typePatternExpr), + column(typePatternExpr) + ) + val initializerAssignmentAst = + callAst(initializerAssignmentCall, Ast(patternIdentifier) :: initializerAst :: Nil) + .withRefEdge(patternIdentifier, patternLocal) + + scope.enclosingMethod.foreach { methodScope => + methodScope.putPatternVariableInfo(typePatternExpr, patternLocal, initializerAssignmentAst) + } + + } + } + + private[astcreation] def instanceOfAstForPattern(patternExpr: PatternExpr, lhsAst: Ast): Ast = { + val patternTreeNode = PatternInitRoot(patternExpr, initAndRefAstsForPatternInitializer(patternExpr, lhsAst)) + val typePatternBuffer: mutable.ListBuffer[PatternInitTreeNode] = mutable.ListBuffer() + if (patternExpr.isTypePatternExpr) { + typePatternBuffer.append(patternTreeNode) + } + val typeCheckAst = typeCheckAstForPattern(patternExpr, patternTreeNode, typePatternBuffer).get + + typePatternBuffer.foreach(createAndPushAssignmentForTypePattern) + typeCheckAst + } + + private def typeCheckAstForPattern( + patternExpr: PatternExpr, + parentInitNode: PatternInitTreeNode, + typePatternBuffer: mutable.ListBuffer[PatternInitTreeNode] + ): Option[Ast] = { + val patternTypeFullName = getPatternTypeFullName(patternExpr) + + val isInstanceOfRequired = + parentInitNode.isInstanceOf[PatternInitRoot] + || !isResolvedTypeFullName(patternTypeFullName) + || !parentInitNode.typeFullName.contains(patternTypeFullName) + + val instanceOfAst = + Option.when(isInstanceOfRequired)(buildInstanceOfAst(patternExpr, parentInitNode, patternTypeFullName)) + + patternExpr match { + case typePatternExpr: TypePatternExpr => + instanceOfAst + + case recordPatternExpr: RecordPatternExpr => + val fieldAccessorInits = + initNodesForRecordFieldAccessors(recordPatternExpr, patternTypeFullName, parentInitNode) + + val fieldInstanceOfAsts = fieldAccessorInits.flatMap { fieldInitNode => + if (fieldInitNode.patternExpr.isTypePatternExpr) { + typePatternBuffer.append(fieldInitNode) + } + typeCheckAstForPattern(fieldInitNode.patternExpr, fieldInitNode, typePatternBuffer).map { ast => + (fieldInitNode.patternExpr, ast) + } + } + + (instanceOfAst.map(ast => (recordPatternExpr, ast)).toList ++ fieldInstanceOfAsts).reverse match { + case Nil => None + + case accumulator :: rest => + val result = rest.foldLeft(accumulator._2) { case (accumulatorAst, (childPattern, astToAdd)) => + val andNode = newOperatorCallNode( + Operators.logicalAnd, + s"(${astToAdd.rootCodeOrEmpty}) && (${accumulatorAst.rootCodeOrEmpty})", + Option(TypeConstants.Boolean), + line(childPattern), + column(childPattern) + ) + + callAst(andNode, astToAdd :: accumulatorAst :: Nil) + } + Option(result) + } + } + } + + private def initNodesForRecordFieldAccessors( + recordPatternExpr: RecordPatternExpr, + recordTypeFullName: String, + parentInitNode: PatternInitTreeNode + ): List[PatternInitNode] = { + val resolvedRecordType = tryWithSafeStackOverflow(recordPatternExpr.getType().resolve().asReferenceType()).toOption + + val patternList = recordPatternExpr.getPatternList.asScala.toList + val fieldNames = resolvedRecordType + .flatMap(_.getTypeDeclaration.toScala) + .map(_.getDeclaredFields.asScala.map(_.getName).toList) + .getOrElse(patternList.map(_ => Defines.UnknownField)) + + patternList.zip(fieldNames).map { case (childPatternExpr, fieldName) => + val childTypeFullName = getPatternTypeFullName(childPatternExpr) match { + case typeFullName if isResolvedTypeFullName(typeFullName) => Option(typeFullName) + case _ => None + } + + val fieldTypeFullName = resolvedRecordType + .flatMap(_.getTypeDeclaration.toScala) + .flatMap(typeDecl => tryWithSafeStackOverflow(typeDecl.getField(fieldName).getType).toOption) + .flatMap(typeInfoCalc.fullName) + + val childIsBranchingNode = + childPatternExpr.isRecordPatternExpr && childPatternExpr.asRecordPatternExpr().getPatternList.size() > 1 + val childTypeIsResolved = childTypeFullName.exists(isResolvedTypeFullName) + val requiresTemporaryVariable = + childIsBranchingNode || !childTypeIsResolved || childTypeFullName != fieldTypeFullName + + PatternInitNode(parentInitNode, childPatternExpr, fieldName, fieldTypeFullName, requiresTemporaryVariable) + } + } + + private def buildInstanceOfAst( + patternExpr: PatternExpr, + parentInitNode: PatternInitTreeNode, + patternTypeFullName: String + ): Ast = { + val patternTypeRef = typeRefNode(patternExpr.getType, code(patternExpr.getType), patternTypeFullName) + val initializerAst = parentInitNode.getAst + + val lhsCode = initializerAst.root match { + case Some(identifier: NewIdentifier) => identifier.code + case Some(call: NewCall) if call.name == Operators.fieldAccess => call.code + case Some(astNodeNew: AstNodeNew) => s"(${astNodeNew.code})" + case _ => "" + } + val instanceOfCall = newOperatorCallNode( + Operators.instanceOf, + s"$lhsCode instanceof ${code(patternExpr.getType)}", + Option(TypeConstants.Boolean) + ) + callAst(instanceOfCall, initializerAst :: Ast(patternTypeRef) :: Nil) + } + + private def getPatternTypeFullName(patternExpr: PatternExpr): String = { + tryWithSafeStackOverflow(patternExpr.getType).toOption + .map(typ => + scope + .lookupScopeType(typ.asString()) + .map(_.typeFullName) + .orElse(typeInfoCalc.fullName(typ)) + .getOrElse(defaultTypeFallback(typ)) + ) + .getOrElse(defaultTypeFallback()) + } +} diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForSimpleExpressionsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForSimpleExpressionsCreator.scala index 7982e905e43e..89649d371bc5 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForSimpleExpressionsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForSimpleExpressionsCreator.scala @@ -21,14 +21,18 @@ import com.github.javaparser.ast.expr.{ UnaryExpr } import com.github.javaparser.ast.nodeTypes.NodeWithName +import com.github.javaparser.ast.visitor.NodeFinderVisitor +import com.github.javaparser.symbolsolver.javaparsermodel.contexts.{BinaryExprContext, ConditionalExprContext} +import com.github.javaparser.symbolsolver.resolution.typesolvers.CombinedTypeSolver import io.joern.javasrc2cpg.astcreation.{AstCreator, ExpectedType} +import io.joern.javasrc2cpg.scope.PatternVariableInfo import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants import io.joern.javasrc2cpg.util.{NameConstants, Util} import io.joern.x2cpg.{Ast, Defines} import io.joern.x2cpg.utils.AstPropertiesUtil.* import io.joern.x2cpg.utils.NodeBuilders.{newIdentifierNode, newOperatorCallNode} -import io.shiftleft.codepropertygraph.generated.nodes.{NewCall, NewFieldIdentifier, NewLiteral, NewTypeRef} +import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewFieldIdentifier, NewLiteral, NewTypeRef} import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Operators} import scala.jdk.CollectionConverters.* @@ -40,9 +44,9 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => private[expressions] def astForArrayAccessExpr(expr: ArrayAccessExpr, expectedType: ExpectedType): Ast = { val typeFullName = expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) + .orElse(getTypeFullName(expectedType)) .map(typeInfoCalc.registerType) - .getOrElse(TypeConstants.Any) + .getOrElse(defaultTypeFallback()) val callNode = newOperatorCallNode( Operators.indexAccess, code = expr.toString, @@ -59,7 +63,11 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => } private[expressions] def astForArrayCreationExpr(expr: ArrayCreationExpr, expectedType: ExpectedType): Ast = { - val maybeInitializerAst = expr.getInitializer.toScala.map(astForArrayInitializerExpr(_, expectedType)) + val elementType = tryWithSafeStackOverflow(expr.getElementType.resolve()).map(elementType => + ExpectedType(typeInfoCalc.fullName(elementType).map(_ ++ "[]"), Option(elementType)) + ) + val maybeInitializerAst = + expr.getInitializer.toScala.map(astForArrayInitializerExpr(_, elementType.getOrElse(expectedType))) maybeInitializerAst.flatMap(_.root) match { case Some(initializerRoot: NewCall) => initializerRoot.code(expr.toString) @@ -68,9 +76,9 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => maybeInitializerAst.getOrElse { val typeFullName = expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) + .orElse(getTypeFullName(expectedType)) .map(typeInfoCalc.registerType) - .getOrElse(TypeConstants.Any) + .getOrElse(defaultTypeFallback(expr.getElementType)) val callNode = newOperatorCallNode(Operators.alloc, code = expr.toString, typeFullName = Some(typeFullName)) val levelAsts = expr.getLevels.asScala.flatMap { lvl => lvl.getDimension.toScala match { @@ -84,11 +92,12 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => } private[expressions] def astForArrayInitializerExpr(expr: ArrayInitializerExpr, expectedType: ExpectedType): Ast = { - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) - .map(typeInfoCalc.registerType) - .getOrElse(TypeConstants.Any) + // In the expression `new int[] { 1, 2 }`, the ArrayInitializerExpr is only the `{ 1, 2 }` part and does not have + // a type itself. We need to use the expected type from the parent expr here. + val typeFullName = getTypeFullName(expectedType) + .map(typeInfoCalc.registerType) + .getOrElse(defaultTypeFallback()) + val callNode = newOperatorCallNode( Operators.arrayInitializer, code = expr.toString, @@ -117,7 +126,7 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => if (expr.getValues.size() > MAX_INITIALIZERS) { val placeholder = NewLiteral() - .typeFullName(TypeConstants.Any) + .typeFullName(defaultTypeFallback()) .code("") .lineNumber(line(expr)) .columnNumber(column(expr)) @@ -150,20 +159,40 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => case BinaryExpr.Operator.REMAINDER => Operators.modulo } - val args = - astsForExpression(expr.getLeft, expectedType) ++ astsForExpression(expr.getRight, expectedType) + val lhsArgs = astsForExpression(expr.getLeft, expectedType) + // TODO Fix code + // val lhsCode = lhsArgs.headOption.flatMap(_.rootCode).getOrElse("") + + scope.pushBlockScope() + val context = new BinaryExprContext(expr, new CombinedTypeSolver()) + + context + .typePatternExprsExposedToChild(expr.getRight) + .asScala + .flatMap(pattern => scope.enclosingMethod.flatMap(_.getPatternVariableInfo(pattern))) + .foreach { case PatternVariableInfo(pattern, local, _, _, _, _) => + scope.enclosingBlock.foreach(_.addPatternLocal(local, pattern)) + } + + val rhsArgs = astsForExpression(expr.getRight, expectedType) + // TODO Fix code + // val rhsCode = rhsArgs.headOption.flatMap(_.rootCode).getOrElse("") + scope.popBlockScope() + + val args = lhsArgs ++ rhsArgs val typeFullName = expressionReturnTypeFullName(expr) .orElse(args.headOption.flatMap(_.rootType)) .orElse(args.lastOption.flatMap(_.rootType)) - .orElse(expectedType.fullName) + .orElse(getTypeFullName(expectedType)) .map(typeInfoCalc.registerType) - .getOrElse(TypeConstants.Any) + .getOrElse(defaultTypeFallback()) val callNode = newOperatorCallNode( operatorName, - code = expr.toString, + // code = s"$lhsCode ${expr.getOperator.asString()} $rhsCode", + code = code(expr), typeFullName = Some(typeFullName), line = line(expr), column = column(expr) @@ -175,14 +204,18 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => private[expressions] def astForCastExpr(expr: CastExpr, expectedType: ExpectedType): Ast = { val typeFullName = tryWithSafeStackOverflow(expr.getType).toOption - .flatMap(typeInfoCalc.fullName) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) + .map { typ => + typeInfoCalc + .fullName(typ) + .orElse(getTypeFullName(expectedType)) + .getOrElse(defaultTypeFallback(typ)) + } + .getOrElse(defaultTypeFallback()) val callNode = newOperatorCallNode( Operators.cast, code = expr.toString, - typeFullName = Some(typeFullName), + typeFullName = Option(typeFullName), line = line(expr), column = column(expr) ) @@ -224,16 +257,29 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => private[expressions] def astForConditionalExpr(expr: ConditionalExpr, expectedType: ExpectedType): Ast = { val condAst = astsForExpression(expr.getCondition, ExpectedType.Boolean) + + val context = new ConditionalExprContext(expr, new CombinedTypeSolver()) + + val patternsExposedToThen = context.typePatternExprsExposedToChild(expr.getThenExpr).asScala.toList + val patternsExposedToElse = context.typePatternExprsExposedToChild(expr.getElseExpr).asScala.toList + + scope.pushBlockScope() + scope.addLocalsForPatternsToEnclosingBlock(patternsExposedToThen) val thenAst = astsForExpression(expr.getThenExpr, expectedType) + scope.popBlockScope() + + scope.pushBlockScope() + scope.addLocalsForPatternsToEnclosingBlock(patternsExposedToElse) val elseAst = astsForExpression(expr.getElseExpr, expectedType) + scope.popBlockScope() val typeFullName = expressionReturnTypeFullName(expr) .orElse(thenAst.headOption.flatMap(_.rootType)) .orElse(elseAst.headOption.flatMap(_.rootType)) - .orElse(expectedType.fullName) + .orElse(getTypeFullName(expectedType)) .map(typeInfoCalc.registerType) - .getOrElse(TypeConstants.Any) + .getOrElse(defaultTypeFallback()) val callNode = newOperatorCallNode(Operators.conditional, expr.toString, Some(typeFullName), line(expr), column(expr)) @@ -248,81 +294,57 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => private[expressions] def astForFieldAccessExpr(expr: FieldAccessExpr, expectedType: ExpectedType): Ast = { val typeFullName = expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) + .orElse(getTypeFullName(expectedType)) .map(typeInfoCalc.registerType) - .getOrElse(TypeConstants.Any) - - val callNode = - newOperatorCallNode(Operators.fieldAccess, expr.toString, Some(typeFullName), line(expr), column(expr)) + .getOrElse(defaultTypeFallback()) val fieldIdentifier = expr.getName val identifierAsts = astsForExpression(expr.getScope, ExpectedType.empty) - val fieldIdentifierNode = NewFieldIdentifier() - .canonicalName(fieldIdentifier.toString) - .lineNumber(line(fieldIdentifier)) - .columnNumber(column(fieldIdentifier)) - .code(fieldIdentifier.toString) - val fieldIdAst = Ast(fieldIdentifierNode) - - callAst(callNode, identifierAsts ++ Seq(fieldIdAst)) - } - private[expressions] def astForInstanceOfExpr(expr: InstanceOfExpr): Ast = { - val booleanTypeFullName = Some(TypeConstants.Boolean) - val callNode = - newOperatorCallNode(Operators.instanceOf, expr.toString, booleanTypeFullName, line(expr), column(expr)) - - val exprAst = astsForExpression(expr.getExpression, ExpectedType.empty) - val exprType = tryWithSafeStackOverflow(expr.getType).toOption - val typeFullName = exprType.flatMap(typeInfoCalc.fullName).getOrElse(TypeConstants.Any) - val typeNode = - NewTypeRef() - .code(exprType.map(_.toString).getOrElse(code(expr).split("instanceof").lastOption.getOrElse(""))) - .lineNumber(line(expr)) - .columnNumber(exprType.map(column(_)).getOrElse(column(expr))) - .typeFullName(typeFullName) - val typeAst = Ast(typeNode) - - callAst(callNode, exprAst ++ Seq(typeAst)) + fieldAccessAst( + identifierAsts.head, + expr.toString, + line(expr), + column(expr), + fieldIdentifier.toString, + typeFullName, + line(fieldIdentifier), + column(fieldIdentifier) + ) } - private[expressions] def fieldAccessAst( - identifierName: String, - identifierType: Option[String], - fieldIdentifierName: String, - returnType: Option[String], - lineNo: Option[Int], - columnNo: Option[Int] - ): Ast = { - val typeFullName = identifierType.orElse(Some(TypeConstants.Any)).map(typeInfoCalc.registerType) - val identifier = newIdentifierNode(identifierName, typeFullName.getOrElse("ANY")) - val maybeCorrespNode = scope.lookupVariable(identifierName).variableNode - - val fieldIdentifier = NewFieldIdentifier() - .code(fieldIdentifierName) - .canonicalName(fieldIdentifierName) - .lineNumber(lineNo) - .columnNumber(columnNo) - - val fieldAccessCode = s"$identifierName.$fieldIdentifierName" - val fieldAccess = - newOperatorCallNode( - Operators.fieldAccess, - fieldAccessCode, - returnType.orElse(Some(TypeConstants.Any)), - lineNo, - columnNo - ) - - val identifierAst = Ast(identifier) - val fieldIdentAst = Ast(fieldIdentifier) - - callAst(fieldAccess, Seq(identifierAst, fieldIdentAst)) - .withRefEdges(identifier, maybeCorrespNode.toList) + private[expressions] def astForInstanceOfExpr(expr: InstanceOfExpr): Ast = { + // TODO: handle multiple ASTs + val lhsAst = astsForExpression(expr.getExpression, ExpectedType.empty).head + expr.getPattern.toScala + .map { patternExpression => + instanceOfAstForPattern(patternExpression, lhsAst) + } + .getOrElse { + val booleanTypeFullName = Some(TypeConstants.Boolean) + val callNode = + newOperatorCallNode(Operators.instanceOf, expr.toString, booleanTypeFullName, line(expr), column(expr)) + + val exprAst = astsForExpression(expr.getExpression, ExpectedType.empty) + val exprType = tryWithSafeStackOverflow(expr.getType).toOption + val typeFullName = exprType + .map(typ => typeInfoCalc.fullName(typ).getOrElse(defaultTypeFallback(typ))) + .getOrElse(defaultTypeFallback()) + val typeNode = + NewTypeRef() + .code(exprType.map(_.toString).getOrElse(code(expr).split("instanceof").lastOption.getOrElse(""))) + .lineNumber(line(expr)) + .columnNumber(exprType.map(column(_)).getOrElse(column(expr))) + .typeFullName(typeFullName) + val typeAst = Ast(typeNode) + + callAst(callNode, exprAst ++ Seq(typeAst)) + } } private[expressions] def astForLiteralExpr(expr: LiteralExpr): Ast = { - val typeFullName = expressionReturnTypeFullName(expr).map(typeInfoCalc.registerType).getOrElse(TypeConstants.Any) + val typeFullName = + expressionReturnTypeFullName(expr).map(typeInfoCalc.registerType).getOrElse(defaultTypeFallback()) val literalNode = NewLiteral() .code(code(expr)) @@ -335,9 +357,9 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => private[expressions] def astForSuperExpr(superExpr: SuperExpr, expectedType: ExpectedType): Ast = { val typeFullName = expressionReturnTypeFullName(superExpr) - .orElse(expectedType.fullName) + .orElse(getTypeFullName(expectedType)) .map(typeInfoCalc.registerType) - .getOrElse(TypeConstants.Any) + .getOrElse(defaultTypeFallback()) val identifier = identifierNode(superExpr, NameConstants.This, NameConstants.Super, typeFullName) Ast(identifier) @@ -346,7 +368,7 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => private[expressions] def astForThisExpr(expr: ThisExpr, expectedType: ExpectedType): Ast = { val typeFullName = expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) + .orElse(getTypeFullName(expectedType)) .map(typeInfoCalc.registerType) val identifier = identifierNode(expr, expr.toString, expr.toString, typeFullName.getOrElse("ANY")) @@ -376,10 +398,11 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => val typeFullName = expressionReturnTypeFullName(expr) .orElse(argsAsts.headOption.flatMap(_.rootType)) - .orElse(expectedType.fullName) + .orElse(getTypeFullName(expectedType)) .map(typeInfoCalc.registerType) - .getOrElse(TypeConstants.Any) + .getOrElse(defaultTypeFallback()) + // TODO Fix code val callNode = newOperatorCallNode( operatorName, code = expr.toString, @@ -408,12 +431,12 @@ trait AstForSimpleExpressionsCreator { this: AstCreator => case Success(resolvedMethod) => val returnType = tryWithSafeStackOverflow(resolvedMethod.getReturnType).toOption.flatMap(typeInfoCalc.fullName) - val parameterTypes = argumentTypesForMethodLike(Success(resolvedMethod)) + val parameterTypes = argumentTypesForMethodLike(Option(resolvedMethod)) composeSignature(returnType, parameterTypes, resolvedMethod.getNumberOfParams) } val methodFullName = Util.composeMethodFullName(namespacePrefix, methodName, signature) - Ast(methodRefNode(expr, expr.toString, methodFullName, typeFullName.getOrElse(TypeConstants.Any))) + Ast(methodRefNode(expr, expr.toString, methodFullName, typeFullName.getOrElse(defaultTypeFallback()))) } } diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForVarDeclAndAssignsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForVarDeclAndAssignsCreator.scala index 174eabe7617a..4c293de6a868 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForVarDeclAndAssignsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForVarDeclAndAssignsCreator.scala @@ -19,6 +19,7 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ NewIdentifier, NewLocal, NewMember, + NewTypeRef, NewUnknown } import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} @@ -28,6 +29,7 @@ import scala.jdk.CollectionConverters.* import scala.jdk.OptionConverters.RichOptional import scala.util.{Failure, Success, Try} import io.joern.javasrc2cpg.scope.JavaScopeElement.PartialInit +import io.joern.x2cpg.utils.NodeBuilders.newIdentifierNode trait AstForVarDeclAndAssignsCreator { this: AstCreator => private val logger = LoggerFactory.getLogger(this.getClass()) @@ -117,15 +119,16 @@ trait AstForVarDeclAndAssignsCreator { this: AstCreator => case None => (None, None) } - val typeFullName = tryWithSafeStackOverflow( + val typeFullNameWithoutArgs = tryWithSafeStackOverflow( variableTypeString .flatMap(scope.lookupType(_, includeWildcards = false)) .orElse(declaratorType.flatMap(typeInfoCalc.fullName)) - ).toOption.flatten.map { typ => - maybeTypeArgs match { - case Some(typeArgs) if keepTypeArguments => s"$typ<${typeArgs.mkString(",")}>" - case _ => typ - } + .orElse(declaratorType.map(typ => defaultTypeFallback(typ))) + ).toOption.flatten.getOrElse(defaultTypeFallback()) + + val typeFullName = maybeTypeArgs match { + case Some(typeArgs) if keepTypeArguments => s"$typeFullNameWithoutArgs<${typeArgs.mkString(",")}>" + case _ => typeFullNameWithoutArgs } val variableName = variableDeclarator.getNameAsString @@ -135,8 +138,15 @@ trait AstForVarDeclAndAssignsCreator { this: AstCreator => // Use type name with generics for code val localCode = s"${declaratorType.map(_.toString).getOrElse("")} ${variableDeclarator.getNameAsString}" + val genericSignature = binarySignatureCalculator.variableBinarySignature(variableDeclarator.getType) val local = - localNode(originNode, variableDeclarator.getNameAsString, localCode, typeFullName.getOrElse(TypeConstants.Any)) + localNode( + originNode, + variableDeclarator.getNameAsString, + localCode, + typeFullName, + genericSignature = Option(genericSignature) + ) scope.enclosingBlock.foreach(_.addLocal(local)) @@ -151,17 +161,13 @@ trait AstForVarDeclAndAssignsCreator { this: AstCreator => case Some(declarationNode) => val assignmentTarget = declarationNode match { case member: NewMember => - val name = - if (scope.isEnclosingScopeStatic) - scope.enclosingTypeDecl.map(_.typeDecl.name).getOrElse(NameConstants.Unknown) - else NameConstants.This - fieldAccessAst( - name, - scope.enclosingTypeDecl.fullName, - declarationNode.name, - Option(declarationNode.typeFullName), - line(originNode), - column(originNode) + createImplicitBaseFieldAccess( + scope.isEnclosingScopeStatic, + scope.enclosingTypeDecl.name.get, + scope.enclosingTypeDecl.fullName.get, + originNode, + variableName, + declarationNode.typeFullName ) case variable => @@ -184,7 +190,7 @@ trait AstForVarDeclAndAssignsCreator { this: AstCreator => initializer, Operators.assignment, "=", - ExpectedType(typeFullName, expectedType), + ExpectedType(Option(typeFullName), expectedType), strippedType ) } diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForForLoopsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForForLoopsCreator.scala index 0370d6ccc52e..96c54ca4ad1a 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForForLoopsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForForLoopsCreator.scala @@ -2,10 +2,13 @@ package io.joern.javasrc2cpg.astcreation.statements import com.github.javaparser.ast.expr.{Expression, NameExpr} import com.github.javaparser.ast.stmt.{BlockStmt, ForEachStmt, ForStmt} +import com.github.javaparser.symbolsolver.javaparsermodel.contexts.ForStatementContext +import com.github.javaparser.symbolsolver.resolution.typesolvers.CombinedTypeSolver import io.joern.javasrc2cpg.astcreation.{AstCreator, ExpectedType} import io.joern.javasrc2cpg.scope.NodeTypeInfo import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants import io.joern.x2cpg.Ast +import io.joern.x2cpg.utils.IntervalKeyPool import io.joern.x2cpg.utils.NodeBuilders.{newCallNode, newFieldIdentifierNode, newIdentifierNode, newOperatorCallNode} import io.shiftleft.codepropertygraph.generated.nodes.Call.PropertyDefaults import io.shiftleft.codepropertygraph.generated.nodes.{ @@ -18,12 +21,10 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ NewNode } import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators} -import io.shiftleft.passes.IntervalKeyPool import org.slf4j.LoggerFactory import scala.jdk.CollectionConverters.* import scala.jdk.OptionConverters.RichOptional -import scala.util.Try trait AstForForLoopsCreator { this: AstCreator => @@ -44,7 +45,9 @@ trait AstForForLoopsCreator { this: AstCreator => s"$IterableNamePrefix${iterableKeyPool.next}" } - def astForFor(stmt: ForStmt): Ast = { + def astsForFor(stmt: ForStmt): List[Ast] = { + val forContext = new ForStatementContext(stmt, new CombinedTypeSolver()) + val forNode = NewControlStructure() .controlStructureType(ControlStructureTypes.FOR) @@ -59,24 +62,41 @@ trait AstForForLoopsCreator { this: AstCreator => astsForExpression(_, ExpectedType.Boolean) } - val updateAsts = stmt.getUpdate.asScala.toList.flatMap { - astsForExpression(_, ExpectedType.empty) + val updateAsts = stmt.getUpdate.asScala.toList match { + case Nil => Nil + + case expressions => + scope.pushBlockScope() + scope.addLocalsForPatternsToEnclosingBlock( + forContext.typePatternExprsExposedToChild(expressions.head).asScala.toList + ) + val updateAsts = expressions.flatMap(astsForExpression(_, ExpectedType.empty)) + scope.popBlockScope() + updateAsts } - val stmtAsts = - astsForStatement(stmt.getBody) + val patternPartition = partitionPatternAstsByScope(forContext) + + scope.pushBlockScope() + scope.addLocalsForPatternsToEnclosingBlock(patternPartition.patternsIntroducedToBody) + val bodyAst = wrapInBlockWithPrefix(patternPartition.astsAddedToBody, stmt.getBody) + scope.popBlockScope() + + scope.addLocalsForPatternsToEnclosingBlock(patternPartition.patternsIntroducedByStatement) val ast = Ast(forNode) .withChildren(initAsts) .withChildren(compareAsts) .withChildren(updateAsts) - .withChildren(stmtAsts) + .withChild(bodyAst) - compareAsts.flatMap(_.root) match { + val astWithConditionEdge = compareAsts.flatMap(_.root) match { case c :: Nil => ast.withConditionEdge(forNode, c) case _ => ast } + + patternPartition.astsAddedBeforeStatement ++ (astWithConditionEdge :: patternPartition.astsAddedAfterStatement) } def astForForEach(stmt: ForEachStmt): Seq[Ast] = { @@ -211,9 +231,16 @@ trait AstForForLoopsCreator { this: AstCreator => iterableAsts.head } - val iterableName = nextIterableName() - val iterableLocalNode = localNode(iterableExpression, iterableName, iterableName, iterableType.getOrElse("ANY")) - val iterableLocalAst = Ast(iterableLocalNode) + val iterableName = nextIterableName() + val genericSignature = binarySignatureCalculator.unspecifiedClassType + val iterableLocalNode = localNode( + iterableExpression, + iterableName, + iterableName, + iterableType.getOrElse("ANY"), + genericSignature = Option(genericSignature) + ) + val iterableLocalAst = Ast(iterableLocalNode) val iterableAssignNode = newOperatorCallNode(Operators.assignment, code = "", line = lineNo, typeFullName = iterableType) @@ -231,14 +258,16 @@ trait AstForForLoopsCreator { this: AstCreator => } private def nativeForEachIdxLocalNode(lineNo: Option[Int]): NewLocal = { - val idxName = nextIndexName() - val typeFullName = TypeConstants.Int + val idxName = nextIndexName() + val typeFullName = TypeConstants.Int + val genericSignature = binarySignatureCalculator.variableBinarySignature(TypeConstants.Int) val idxLocal = NewLocal() .name(idxName) .typeFullName(typeFullName) .code(idxName) .lineNumber(lineNo) + .genericSignature(genericSignature) scope.enclosingBlock.get.addLocal(idxLocal) idxLocal } @@ -318,7 +347,10 @@ trait AstForForLoopsCreator { this: AstCreator => Some(variable) } + val genericSignature = + maybeVariable.map(variable => binarySignatureCalculator.variableBinarySignature(variable.getType)) val partialLocalNode = NewLocal().lineNumber(lineNo) + genericSignature.foreach(partialLocalNode.genericSignature(_)) maybeVariable match { case Some(variable) => @@ -341,11 +373,13 @@ trait AstForForLoopsCreator { this: AstCreator => private def iteratorLocalForForEach(lineNumber: Option[Int]): NewLocal = { val iteratorLocalName = nextIterableName() + val genericSignature = binarySignatureCalculator.variableBinarySignature(TypeConstants.Iterator) NewLocal() .name(iteratorLocalName) .code(iteratorLocalName) .typeFullName(TypeConstants.Iterator) .lineNumber(lineNumber) + .genericSignature(genericSignature) } private def iteratorAssignAstForForEach( diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForSimpleStatementsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForSimpleStatementsCreator.scala index 829dea82a3a3..d012f3fab080 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForSimpleStatementsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForSimpleStatementsCreator.scala @@ -1,5 +1,6 @@ package io.joern.javasrc2cpg.astcreation.statements +import com.github.javaparser.ast.expr.{Expression, PatternExpr, TypePatternExpr} import com.github.javaparser.ast.stmt.{ AssertStmt, BlockStmt, @@ -19,20 +20,60 @@ import com.github.javaparser.ast.stmt.{ TryStmt, WhileStmt } +import com.github.javaparser.symbolsolver.javaparsermodel.PatternVariableVisitor +import com.github.javaparser.symbolsolver.javaparsermodel.contexts.{ + DoStatementContext, + ExpressionContext, + IfStatementContext, + SwitchEntryContext, + WhileStatementContext +} +import com.github.javaparser.symbolsolver.resolution.typesolvers.CombinedTypeSolver import io.joern.javasrc2cpg.astcreation.{AstCreator, ExpectedType} +import io.joern.javasrc2cpg.scope.PatternVariableInfo import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants import io.joern.javasrc2cpg.util.NameConstants import io.joern.x2cpg.Ast import io.joern.x2cpg.utils.NodeBuilders.{newIdentifierNode, newModifierNode} -import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewCall, NewControlStructure, NewJumpTarget, NewReturn} +import io.shiftleft.codepropertygraph.generated.nodes.{ + NewBlock, + NewCall, + NewControlStructure, + NewIdentifier, + NewJumpTarget, + NewReturn +} import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, EdgeTypes} +import io.joern.x2cpg.utils.AstPropertiesUtil.* import scala.jdk.CollectionConverters.* import scala.jdk.OptionConverters.RichOptional import io.joern.javasrc2cpg.scope.JavaScopeElement.PartialInit +import io.joern.javasrc2cpg.scope.Scope.NewVariableNode +import org.slf4j.LoggerFactory + +private case class PatternAstPartition( + patternsIntroducedToBody: List[TypePatternExpr], + patternsIntroducedToElse: List[TypePatternExpr], + patternsIntroducedByStatement: List[TypePatternExpr], + astsAddedBeforeStatement: List[Ast], + astsAddedToBody: List[Ast], + astsAddedToElse: List[Ast], + astsAddedAfterStatement: List[Ast] +) { + val addedAsts: Set[Ast] = + (astsAddedBeforeStatement ++ astsAddedToBody ++ astsAddedToElse ++ astsAddedAfterStatement).toSet +} trait AstForSimpleStatementsCreator { this: AstCreator => - def astForBlockStatement(stmt: BlockStmt, codeStr: String = "", prefixAsts: Seq[Ast] = Seq.empty): Ast = { + private val logger = LoggerFactory.getLogger(this.getClass) + + def astForBlockStatement( + stmt: BlockStmt, + codeStr: String = "", + prefixAsts: Seq[Ast] = Seq.empty, + includeTemporaryLocals: Boolean = false + ): Ast = { val block = NewBlock() .code(codeStr) .lineNumber(line(stmt)) @@ -42,8 +83,15 @@ trait AstForSimpleStatementsCreator { this: AstCreator => val stmtAsts = stmt.getStatements.asScala.flatMap(astsForStatement) + val temporaryLocalAsts = + if (includeTemporaryLocals) + scope.enclosingMethod.map(_.getTemporaryLocals).getOrElse(Nil).map(Ast(_)) + else + Nil + scope.popBlockScope() Ast(block) + .withChildren(temporaryLocalAsts) .withChildren(prefixAsts) .withChildren(stmtAsts) } @@ -52,22 +100,18 @@ trait AstForSimpleStatementsCreator { this: AstCreator => // TODO Handle super val maybeResolved = tryWithSafeStackOverflow(stmt.resolve()) val args = argAstsForCall(stmt, maybeResolved, stmt.getArguments) - val argTypes = argumentTypesForMethodLike(maybeResolved) + val argTypes = argumentTypesForMethodLike(maybeResolved.toOption) + // TODO: We can do better than defaultTypeFallback() for the fallback type by looking at the enclosing + // type decl name or `extends X` name for `this` and `super` calls respectively. val typeFullName = maybeResolved.toOption .map(_.declaringType()) .flatMap(typ => scope.lookupType(typ.getName).orElse(typeInfoCalc.fullName(typ))) + .getOrElse(defaultTypeFallback()) - val callRoot = initNode( - typeFullName.orElse(Some(TypeConstants.Any)), - argTypes, - args.size, - stmt.toString, - line(stmt), - column(stmt) - ) + val callRoot = initNode(Option(typeFullName), argTypes, args.size, stmt.toString, line(stmt), column(stmt)) - val thisNode = newIdentifierNode(NameConstants.This, typeFullName.getOrElse(TypeConstants.Any)) + val thisNode = newIdentifierNode(NameConstants.This, typeFullName) scope.lookupVariable(NameConstants.This).variableNode.foreach { thisParam => diffGraph.addEdge(thisNode, thisParam, EdgeTypes.REF) } @@ -77,9 +121,7 @@ trait AstForSimpleStatementsCreator { this: AstCreator => // callAst(callRoot, args, Some(thisAst)) scope.enclosingTypeDecl.foreach( - _.registerInitToComplete( - PartialInit(typeFullName.getOrElse(TypeConstants.Any), initAst, thisAst, args.toList, None) - ) + _.registerInitToComplete(PartialInit(typeFullName, initAst, thisAst, args.toList, None)) ) initAst } @@ -115,51 +157,106 @@ trait AstForSimpleStatementsCreator { this: AstCreator => Ast(node) } - private[statements] def astForDo(stmt: DoStmt): Ast = { + private[statements] def astsForDo(stmt: DoStmt): List[Ast] = { val conditionAst = astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption - val stmtAsts = astsForStatement(stmt.getBody) + + val doContext = new DoStatementContext(stmt, new CombinedTypeSolver()) + + val patternPartition = partitionPatternAstsByScope(doContext) + + scope.pushBlockScope() + scope.addLocalsForPatternsToEnclosingBlock(patternPartition.patternsIntroducedToBody) + val bodyAst = wrapInBlockWithPrefix(patternPartition.astsAddedToBody, stmt.getBody) + scope.popBlockScope() + + scope.addLocalsForPatternsToEnclosingBlock(patternPartition.patternsIntroducedByStatement) + val code = s"do {...} while (${stmt.getCondition.toString})" val lineNumber = line(stmt) val columnNumber = column(stmt) - doWhileAst(conditionAst, stmtAsts, Some(code), lineNumber, columnNumber) + val ast = doWhileAst(conditionAst, List(bodyAst), Some(code), lineNumber, columnNumber) + patternPartition.astsAddedBeforeStatement ++ (ast :: patternPartition.astsAddedAfterStatement) } - private[statements] def astForWhile(stmt: WhileStmt): Ast = { + private[statements] def astsForWhile(stmt: WhileStmt): List[Ast] = { val conditionAst = astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption - val stmtAsts = astsForStatement(stmt.getBody) + + val whileContext = new WhileStatementContext(stmt, new CombinedTypeSolver()) + + val patternPartition = partitionPatternAstsByScope(whileContext) + + scope.pushBlockScope() + scope.addLocalsForPatternsToEnclosingBlock(patternPartition.patternsIntroducedToBody) + val bodyAst = wrapInBlockWithPrefix(patternPartition.astsAddedToBody, stmt.getBody) + scope.popBlockScope() + + scope.addLocalsForPatternsToEnclosingBlock(patternPartition.patternsIntroducedByStatement) + val code = s"while (${stmt.getCondition.toString})" val lineNumber = line(stmt) val columnNumber = column(stmt) - whileAst(conditionAst, stmtAsts, Some(code), lineNumber, columnNumber) + val ast = whileAst(conditionAst, List(bodyAst), Some(code), lineNumber, columnNumber) + patternPartition.astsAddedBeforeStatement ++ (ast :: patternPartition.astsAddedAfterStatement) } - private[statements] def astForIf(stmt: IfStmt): Ast = { + private[statements] def astsForIf(stmt: IfStmt): Seq[Ast] = { + + val conditionAst = + astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption.toList + + val ifContext = new IfStatementContext(stmt, new CombinedTypeSolver()) + + val patternPartition = partitionPatternAstsByScope(ifContext) + + scope.pushBlockScope() + scope.addLocalsForPatternsToEnclosingBlock(patternPartition.patternsIntroducedToBody) + val thenAst = stmt.getThenStmt match { + case blockStmt: BlockStmt => astForBlockStatement(blockStmt, prefixAsts = patternPartition.astsAddedToBody) + + case stmt: Statement => + val elseStmts = astsForStatement(stmt) + blockAst(blockNode(stmt), patternPartition.astsAddedToBody ++ elseStmts) + } + scope.popBlockScope() + + scope.pushBlockScope() + scope.addLocalsForPatternsToEnclosingBlock(patternPartition.patternsIntroducedToElse) + val elseAst = stmt.getElseStmt.toScala.map { elseStmt => + val elseBodyStatements = elseStmt match { + case blockStmt: BlockStmt => blockStmt.getStatements.asScala + case elseStmt: Statement => elseStmt :: Nil + } + + val elseBodyAsts = elseBodyStatements.flatMap(astsForStatement) + val elseBlock = blockAst(blockNode(elseStmt), patternPartition.astsAddedToElse ++ elseBodyAsts) + val elseNode = controlStructureNode(elseStmt, ControlStructureTypes.ELSE, "else") + controlStructureAst(elseNode, None, elseBlock :: Nil) + } + scope.popBlockScope() + val ifNode = NewControlStructure() .controlStructureType(ControlStructureTypes.IF) .lineNumber(line(stmt)) .columnNumber(column(stmt)) - .code(s"if (${stmt.getCondition.toString})") - - val conditionAst = - astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption.toList - - val thenAsts = astsForStatement(stmt.getThenStmt) - val elseAst = astForElse(stmt.getElseStmt.toScala).toList + .code(s"if (${conditionAst.headOption.flatMap(_.rootCode).getOrElse("")})") + scope.addLocalsForPatternsToEnclosingBlock(patternPartition.patternsIntroducedByStatement) val ast = Ast(ifNode) .withChildren(conditionAst) - .withChildren(thenAsts) - .withChildren(elseAst) + .withChild(thenAst) + .withChildren(elseAst.toList) - conditionAst.flatMap(_.root.toList) match { + val astWithConditionEdge = conditionAst.flatMap(_.root.toList) match { case r :: Nil => ast.withConditionEdge(ifNode, r) case _ => ast } + + patternPartition.astsAddedBeforeStatement ++ (astWithConditionEdge :: patternPartition.astsAddedAfterStatement) } private[statements] def astForElse(maybeStmt: Option[Statement]): Option[Ast] = { @@ -178,20 +275,42 @@ trait AstForSimpleStatementsCreator { this: AstCreator => } private[statements] def astForSwitchStatement(stmt: SwitchStmt): Ast = { + // TODO: Add support for switch expressions + // TODO: Switch expressions should either be represented with MATCH or we should add a break. val switchNode = NewControlStructure() .controlStructureType(ControlStructureTypes.SWITCH) .code(s"switch(${stmt.getSelector.toString})") - val selectorAsts = astsForExpression(stmt.getSelector, ExpectedType.empty) - val selectorNode = selectorAsts.head.root.get + val selectorAst = astsForExpression(stmt.getSelector, ExpectedType.empty) match { + case Seq() => + throw new IllegalArgumentException(s"Got an empty ast list for expression ${code(stmt.getSelector)}") + + case Seq(ast) => ast + + case asts => + logger.warn(s"Found multiple asts for selector expression ${code(stmt.getSelector)}") + asts.head + } + + val selectorNode = selectorAst.root.get + + val selectorMustBeIdentifierOrFieldAccess = + stmt.getEntries.asScala.flatMap(_.getLabels.asScala).exists(_.isPatternExpr) + + val (initializerAst, referenceAst) = if (selectorMustBeIdentifierOrFieldAccess) { + val initAndRefAsts = initAndRefAstsForPatternInitializer(stmt.getSelector, selectorAst) + (initAndRefAsts.get, Option(initAndRefAsts.get)) + } else { + (selectorAst, None) + } - val entryAsts = stmt.getEntries.asScala.flatMap(astForSwitchEntry) + val entryAsts = stmt.getEntries.asScala.flatMap(astForSwitchEntry(_, referenceAst)) val switchBodyAst = Ast(NewBlock()).withChildren(entryAsts) Ast(switchNode) - .withChildren(selectorAsts) + .withChild(initializerAst) .withChild(switchBodyAst) .withConditionEdge(switchNode, selectorNode) } @@ -213,32 +332,88 @@ trait AstForSimpleStatementsCreator { this: AstCreator => .withChild(bodyAst) } - private[statements] def astsForSwitchCases(entry: SwitchEntry): Seq[Ast] = { - entry.getLabels.asScala.toList match { - case Nil => - val target = NewJumpTarget() - .name("default") - .code("default") - Seq(Ast(target)) - - case labels => - labels.flatMap { label => - val jumpTarget = NewJumpTarget() - .name("case") - .code(label.toString) - val labelAsts = astsForExpression(label, ExpectedType.empty).toList - - Ast(jumpTarget) :: labelAsts - } + private def astsForSwitchLabels(labels: List[Expression], isDefault: Boolean): Seq[Ast] = { + val defaultAst = Option.when(labels.isEmpty || isDefault) { + val target = NewJumpTarget() + .name("default") + .code("default") + Ast(target) + } + + val explicitLabelAsts = labels.flatMap { label => + val jumpTarget = NewJumpTarget() + .name("case") + .code(code(label)) + val labelAsts = + if (label.isPatternExpr) + Nil + else + astsForExpression(label, ExpectedType.empty).toList + + Ast(jumpTarget) :: labelAsts } + + (defaultAst ++ explicitLabelAsts).toList } - private[statements] def astForSwitchEntry(entry: SwitchEntry): Seq[Ast] = { - val labelAsts = astsForSwitchCases(entry) + private def astForSwitchEntry(entry: SwitchEntry, selectorReferenceAst: Option[Ast]): Seq[Ast] = { + // Fallthrough to/from a pattern is a compile error, so an entry can only have a pattern label if that is + // the only label + val labels = entry.getLabels.asScala.toList + val labelAsts = astsForSwitchLabels(labels, entry.isDefault) - val statementAsts = entry.getStatements.asScala.flatMap(astsForStatement) + val entryContext = new SwitchEntryContext(entry, new CombinedTypeSolver()) + + val instanceOfAst = labels.lastOption.collect { case patternExpr: PatternExpr => + selectorReferenceAst.map { selectorAst => + instanceOfAstForPattern(patternExpr, selectorAst) + } + }.flatten + + // TODO: Add variable local and assignment to entry body even if there are no statements + if (entry.getStatements.isEmpty) { + labelAsts + } else { + scope.pushBlockScope() + val patternsExposedToBody = entryContext.typePatternExprsExposedToChild(entry.getStatements.get(0)).asScala.toList + scope.addLocalsForPatternsToEnclosingBlock(patternsExposedToBody) + val patternAstsToAdd = patternsExposedToBody + .flatMap(typePattern => scope.enclosingMethod.get.getPatternVariableInfo(typePattern)) + .flatMap { case PatternVariableInfo(typePatternExpr, patternLocal, initializerAst, _, _, _) => + scope.enclosingMethod.get.registerPatternVariableInitializerToBeAddedToGraph(typePatternExpr) + scope.enclosingMethod.get.registerPatternVariableLocalToBeAddedToGraph(typePatternExpr) + Ast(patternLocal) :: initializerAst :: Nil + } - labelAsts ++ statementAsts + val guardAst = entry.getGuard.toScala.map(astsForExpression(_, ExpectedType.Boolean)) + + val statementsAst = guardAst match { + case None => wrapInBlockWithPrefix(patternAstsToAdd, entry.getStatements.asScala.toList) + + case Some(guard) => + val bodyAst = wrapInBlockWithPrefix(Nil, entry.getStatements.asScala.toList) + + val ifNode = controlStructureNode( + entry.getGuard.get(), + ControlStructureTypes.IF, + s"if (${guard.headOption.flatMap(_.rootCode)})" + ) + + val ifAst = controlStructureAst(ifNode, guard.headOption, bodyAst :: Nil) + patternAstsToAdd match { + case Nil => ifAst + case _ => blockAst(blockNode(entry)).withChildren(patternAstsToAdd).withChild(ifAst) + } + } + scope.popBlockScope() + + instanceOfAst + .map { instanceOfAst => + val ifNode = controlStructureNode(entry, ControlStructureTypes.IF, s"if (${instanceOfAst.rootCodeOrEmpty})") + labelAsts :+ controlStructureAst(ifNode, Option(instanceOfAst), statementsAst :: Nil) + } + .getOrElse(labelAsts :+ statementsAst) + } } private[statements] def astForReturnNode(ret: ReturnStmt): Ast = { diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForStatementsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForStatementsCreator.scala index 5613e646cd0d..2a16c8d3845c 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForStatementsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/statements/AstForStatementsCreator.scala @@ -1,5 +1,6 @@ package io.joern.javasrc2cpg.astcreation.statements +import com.github.javaparser.ast.expr.TypePatternExpr import com.github.javaparser.ast.stmt.{ AssertStmt, BlockStmt, @@ -25,6 +26,15 @@ import io.joern.javasrc2cpg.astcreation.{AstCreator, ExpectedType} import io.joern.x2cpg.Ast import org.slf4j.LoggerFactory import com.github.javaparser.ast.stmt.LocalClassDeclarationStmt +import com.github.javaparser.symbolsolver.javaparsermodel.contexts.StatementContext +import io.shiftleft.codepropertygraph.generated.nodes.NewBlock +import io.joern.javasrc2cpg.scope.PatternVariableInfo + +import java.util +import java.util.Collections +import scala.collection.mutable +import scala.jdk.CollectionConverters.* +import scala.jdk.OptionConverters.* trait AstForStatementsCreator extends AstForSimpleStatementsCreator with AstForForLoopsCreator { this: AstCreator => @@ -35,31 +45,158 @@ trait AstForStatementsCreator extends AstForSimpleStatementsCreator with AstForF // case _: LocalClassDeclarationStmt => Seq() // case _: LocalRecordDeclarationStmt => Seq() // case _: YieldStmt => Seq() - statement match { + val statementAsts = statement match { case x: ExplicitConstructorInvocationStmt => Seq(astForExplicitConstructorInvocation(x)) case x: AssertStmt => Seq(astForAssertStatement(x)) case x: BlockStmt => Seq(astForBlockStatement(x)) case x: BreakStmt => Seq(astForBreakStatement(x)) case x: ContinueStmt => Seq(astForContinueStatement(x)) - case x: DoStmt => Seq(astForDo(x)) + case x: DoStmt => astsForDo(x) case _: EmptyStmt => Seq() // Intentionally skipping this case x: ExpressionStmt => astsForExpression(x.getExpression, ExpectedType.Void) case x: ForEachStmt => astForForEach(x) - case x: ForStmt => Seq(astForFor(x)) - case x: IfStmt => Seq(astForIf(x)) + case x: ForStmt => astsForFor(x) + case x: IfStmt => astsForIf(x) case x: LabeledStmt => astsForLabeledStatement(x) case x: ReturnStmt => Seq(astForReturnNode(x)) case x: SwitchStmt => Seq(astForSwitchStatement(x)) case x: SynchronizedStmt => Seq(astForSynchronizedStatement(x)) case x: ThrowStmt => Seq(astForThrow(x)) case x: TryStmt => astsForTry(x) - case x: WhileStmt => Seq(astForWhile(x)) + case x: WhileStmt => astsForWhile(x) case x: LocalClassDeclarationStmt => Seq(astForLocalClassDeclaration(x)) case x => logger.warn(s"Attempting to generate AST for unknown statement of type ${x.getClass}") Seq(unknownAst(x)) } + + val patternVariableAsts = + scope.enclosingMethod + .map { enclosingMethod => enclosingMethod.getUnaddedPatternVariableAstsAndMarkAdded() } + .getOrElse(Nil) + patternVariableAsts ++ statementAsts + } + + private[statements] def partitionPatternAstsByScope(context: StatementContext[?]): PatternAstPartition = { + val patternsIntroducedByStmt = + Collections.newSetFromMap(new util.IdentityHashMap[TypePatternExpr, java.lang.Boolean]()) + patternsIntroducedByStmt.addAll(context.getIntroducedTypePatterns) + + val patternsIntroducedToBody = + Collections.newSetFromMap(new util.IdentityHashMap[TypePatternExpr, java.lang.Boolean](2)) + + val body = context.getWrappedNode match { + case ifStmt: IfStmt => ifStmt.getThenStmt + case whileStmt: WhileStmt => whileStmt.getBody + case forStmt: ForStmt => forStmt.getBody + case forEachStmt: ForEachStmt => forEachStmt.getBody + case doStmt: DoStmt => doStmt.getBody + case other => + throw new IllegalArgumentException( + s"Trying to partition pattern asts for invalid node type ${other.getClass.getName}" + ) + } + patternsIntroducedToBody.addAll(context.typePatternExprsExposedToChild(body)) + + val patternsIntroducedToElse = + Collections.newSetFromMap(new util.IdentityHashMap[TypePatternExpr, java.lang.Boolean]()) + + context.getWrappedNode match { + case ifStmt: IfStmt if ifStmt.getElseStmt.isPresent => + patternsIntroducedToElse.addAll(context.typePatternExprsExposedToChild(ifStmt.getElseStmt.get())) + case _ => // Nothing to do in this case + } + + val astsAddedBeforeStmt = mutable.ListBuffer[Ast]() + val astsAddedAfterStmt = mutable.ListBuffer[Ast]() + val astsAddedToBody = mutable.ListBuffer[Ast]() + val astsAddedToElse = mutable.ListBuffer[Ast]() + + val patternSet = Collections.newSetFromMap(new util.IdentityHashMap[TypePatternExpr, java.lang.Boolean]()) + + // patterns that are introduced or used in the comparison expression, but not introduced to the + // then or else blocks, or the outer scope. + val patternsDefinedInConditions = context.getWrappedNode + .match { + case ifStmt: IfStmt => Seq(ifStmt.getCondition) + case whileStmt: WhileStmt => Seq(whileStmt.getCondition) + case forEachStmt: ForEachStmt => Seq() + case doStmt: DoStmt => Seq(doStmt.getCondition) + case forStmt: ForStmt => + forStmt.getInitialization.asScala ++ forStmt.getCompare.toScala ++ forStmt.getUpdate.asScala + } + .flatMap(_.findAll(classOf[TypePatternExpr]).asScala) + + patternSet.addAll(patternsDefinedInConditions.asJava) + + patternSet.asScala + .flatMap(patternExpr => scope.enclosingMethod.flatMap(_.getPatternVariableInfo(patternExpr))) + .toArray + .sortBy(_.index) + .foreach { + case PatternVariableInfo(pattern, variableLocal, _, _, true, _) => + scope.enclosingMethod.foreach(_.registerPatternVariableLocalToBeAddedToGraph(pattern)) + astsAddedBeforeStmt.addOne(Ast(variableLocal)) + + case PatternVariableInfo(pattern, variableLocal, initializer, _, false, _) => + if (patternsIntroducedByStmt.contains(pattern)) { + if (patternsIntroducedToBody.contains(pattern) || patternsIntroducedToElse.contains(pattern)) { + astsAddedBeforeStmt.addOne(Ast(variableLocal)) + astsAddedBeforeStmt.addOne(initializer) + } else { + astsAddedAfterStmt.addOne(Ast(variableLocal)) + astsAddedAfterStmt.addOne(initializer) + } + } else { + if (patternsIntroducedToBody.contains(pattern)) { + astsAddedToBody.addOne(Ast(variableLocal)) + astsAddedToBody.addOne(initializer) + } else if (patternsIntroducedToElse.contains(pattern)) { + astsAddedToElse.addOne(Ast(variableLocal)) + astsAddedToElse.addOne(initializer) + } + } + scope.enclosingMethod.foreach(_.registerPatternVariableInitializerToBeAddedToGraph(pattern)) + scope.enclosingMethod.foreach(_.registerPatternVariableLocalToBeAddedToGraph(pattern)) + + } + + PatternAstPartition( + patternsIntroducedToBody.asScala.toList, + patternsIntroducedToElse.asScala.toList, + patternsIntroducedByStmt.asScala.toList, + astsAddedBeforeStmt.toList, + astsAddedToBody.toList, + astsAddedToElse.toList, + astsAddedAfterStmt.toList + ) + } + + private[statements] def wrapInBlockWithPrefix(prefixAsts: List[Ast], stmt: Statement): Ast = { + wrapInBlockWithPrefix(prefixAsts, stmt :: Nil) + } + + private[statements] def wrapInBlockWithPrefix(prefixAsts: List[Ast], stmts: List[Statement]): Ast = { + stmts match { + case Seq() => Ast(NewBlock()).withChildren(prefixAsts) + case Seq(blockStmt: BlockStmt) => astForBlockStatement(blockStmt, prefixAsts = prefixAsts) + + case Seq(singleStmt) => + val stmtAsts = astsForStatement(singleStmt) + stmtAsts.toList match { + case bodyStmt :: Nil if prefixAsts.isEmpty => bodyStmt + case _ => blockAst(blockNode(singleStmt), prefixAsts ++ stmtAsts) + } + + case _ => + val stmtsAsts = stmts.flatMap(astsForStatement) + stmtsAsts match { + case Nil => Ast(NewBlock()).withChildren(prefixAsts) + case bodyStmt :: Nil if prefixAsts.isEmpty => bodyStmt + case _ => blockAst(blockNode(stmts.head), prefixAsts ++ stmtsAsts) + } + } } } diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala index 2984520ca54c..e053b780a2f0 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.jartypereader.descriptorparser import io.joern.javasrc2cpg.jartypereader.model.PrimitiveType -import io.joern.javasrc2cpg.jartypereader.model.Model.TypeConstants._ +import io.joern.javasrc2cpg.jartypereader.model.Model.TypeConstants.* import org.slf4j.LoggerFactory import scala.util.parsing.combinator.RegexParsers diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala index d1f735573013..6694c27abf3a 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.jartypereader.descriptorparser import io.joern.javasrc2cpg.jartypereader.model.Bound.{BoundAbove, BoundBelow} -import io.joern.javasrc2cpg.jartypereader.model._ +import io.joern.javasrc2cpg.jartypereader.model.* import org.slf4j.LoggerFactory trait TypeParser extends TokenParser { diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala index 4694b137d250..4f3050691497 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala @@ -63,7 +63,7 @@ class AstCreationPass(config: Config, cpg: Cpg, sourcesOverride: Option[List[Str symbolSolver, config.keepTypeArguments, loggedExceptionCounts - )(config.schemaValidation) + )(config.schemaValidation, config.disableTypeFallback) .createAst() ) @@ -87,7 +87,10 @@ class AstCreationPass(config: Config, cpg: Cpg, sourcesOverride: Option[List[Str private def getDependencyList(inputPath: String): List[String] = { val envVarValue = Option(System.getenv(JavaSrcEnvVar.FetchDependencies.name)) - val shouldFetch = if (envVarValue.exists(_.nonEmpty)) { + val shouldFetch = if (envVarValue.contains("no-fetch")) { + logger.info(s"Disabling dependency fetching as envvar is set to \"no-fetch\"") + false + } else if (envVarValue.exists(_.nonEmpty)) { logger.info(s"Enabling dependency fetching: Environment variable ${JavaSrcEnvVar.FetchDependencies.name} is set") true } else if (config.fetchDependencies) { diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/OuterClassRefPass.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/OuterClassRefPass.scala new file mode 100644 index 000000000000..1c4c6023644b --- /dev/null +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/OuterClassRefPass.scala @@ -0,0 +1,24 @@ +package io.joern.javasrc2cpg.passes; + +import io.joern.javasrc2cpg.util.NameConstants +import io.joern.x2cpg.Defines +import io.shiftleft.codepropertygraph.Cpg +import io.shiftleft.codepropertygraph.generated.EdgeTypes +import io.shiftleft.codepropertygraph.generated.nodes.{Method, TypeDecl} +import io.shiftleft.passes.ForkJoinParallelCpgPass +import io.shiftleft.semanticcpg.language.* + +class OuterClassRefPass(cpg: Cpg) extends ForkJoinParallelCpgPass[TypeDecl](cpg) { + override def generateParts(): Array[TypeDecl] = cpg.typeDecl.toArray + + override def runOnPart(diffGraph: DiffGraphBuilder, typeDecl: TypeDecl): Unit = { + typeDecl.method.nameExact(Defines.ConstructorMethodName).foreach { constructor => + constructor.ast.isIdentifier.nameExact(NameConstants.OuterClass).filter(_.refsTo.isEmpty).foreach { + outerClassIdentifier => + constructor.parameter.nameExact(NameConstants.OuterClass).foreach { outerClassParam => + diffGraph.addEdge(outerClassIdentifier, outerClassParam, EdgeTypes.REF) + } + } + } + } +} diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala index 34cd3c42241c..a92e51fe2c42 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala @@ -3,11 +3,10 @@ package io.joern.javasrc2cpg.passes import com.github.javaparser.symbolsolver.cache.GuavaCache import com.google.common.cache.CacheBuilder import io.joern.x2cpg.Defines -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.ModifierTypes +import io.shiftleft.codepropertygraph.generated.{Cpg, ModifierTypes, Properties} import io.shiftleft.codepropertygraph.generated.nodes.{Call, Method} import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.slf4j.LoggerFactory import scala.jdk.OptionConverters.RichOptional @@ -55,11 +54,8 @@ class TypeInferencePass(cpg: Cpg) extends ForkJoinParallelCpgPass[Call](cpg) { val callArgs = if (skipCallThis) call.argument.toList.tail else call.argument.toList val hasDifferingArg = method.parameter.zip(callArgs).exists { case (parameter, argument) => - val maybeArgumentType = Option(argument.property(PropertyNames.TypeFullName)) - .map(_.toString()) - .getOrElse(TypeConstants.Any) - - val argMatches = maybeArgumentType == TypeConstants.Any || maybeArgumentType == parameter.typeFullName + val maybeArgumentType = argument.propertyOption(Properties.TypeFullName).getOrElse(TypeConstants.Any) + val argMatches = maybeArgumentType == TypeConstants.Any || maybeArgumentType == parameter.typeFullName !argMatches } @@ -80,10 +76,8 @@ class TypeInferencePass(cpg: Cpg) extends ForkJoinParallelCpgPass[Call](cpg) { } private def getReplacementMethod(call: Call): Option[Method] = { - val argTypes = - call.argument.flatMap(arg => Option(arg.property(PropertyNames.TypeFullName)).map(_.toString)).mkString(":") - val callKey = - s"${call.methodFullName}:$argTypes" + val argTypes = call.argument.property(Properties.TypeFullName).mkString(":") + val callKey = s"${call.methodFullName}:$argTypes" cache.get(callKey).toScala.getOrElse { val callNameParts = getNameParts(call.name, call.methodFullName) resolvedMethodIndex.get(call.name).flatMap { candidateMethods => diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/JavaScopeElement.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/JavaScopeElement.scala index 04f5df7ea105..30a003445e31 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/JavaScopeElement.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/JavaScopeElement.scala @@ -1,23 +1,35 @@ package io.joern.javasrc2cpg.scope +import com.github.javaparser.ast.body.Parameter +import com.github.javaparser.ast.expr.TypePatternExpr import io.joern.javasrc2cpg.scope.Scope.* import io.joern.javasrc2cpg.scope.JavaScopeElement.* import io.shiftleft.codepropertygraph.generated.nodes.{NewImport, NewMethod, NewNamespaceBlock, NewTypeDecl} import scala.collection.mutable import io.joern.javasrc2cpg.astcreation.ExpectedType +import io.joern.javasrc2cpg.scope.TypeType.{ReferenceTypeType, TypeVariableType} import io.joern.javasrc2cpg.util.MultiBindingTableAdapterForJavaparser.JavaparserBindingDeclType import io.shiftleft.codepropertygraph.generated.nodes.NewMethodParameterIn import io.shiftleft.codepropertygraph.generated.nodes.NewLocal import io.shiftleft.codepropertygraph.generated.nodes.NewMember import io.joern.javasrc2cpg.util.{BindingTable, BindingTableEntry, NameConstants} -import io.shiftleft.passes.IntervalKeyPool -import io.joern.x2cpg.Ast +import io.joern.x2cpg.utils.IntervalKeyPool +import io.joern.x2cpg.{Ast, ValidationMode} -trait JavaScopeElement { +import java.util +import scala.jdk.CollectionConverters.* + +enum TypeType: + case ReferenceTypeType, TypeVariableType + +trait JavaScopeElement(disableTypeFallback: Boolean) { private val variables = mutable.Map[String, ScopeVariable]() private val types = mutable.Map[String, ScopeType]() private var wildcardImports: WildcardImports = NoWildcard + // TODO: This is almost a duplicate of types, but not quite since this stores type variables and local types + // with original names. See if there's a way to combine them + private val declaredTypeTypes = mutable.Map[String, TypeType]() def isStatic: Boolean @@ -25,6 +37,21 @@ trait JavaScopeElement { variables.put(variable.name, variable) } + def addDeclaredTypeType(typeSimpleName: String, isTypeVariable: Boolean): Unit = { + val typeType = + if (isTypeVariable) + TypeVariableType + else { + ReferenceTypeType + } + + declaredTypeTypes.put(typeSimpleName, typeType) + } + + def getDeclaredTypeType(typeSimpleName: String): Option[TypeType] = { + declaredTypeTypes.get(typeSimpleName) + } + def lookupVariable(name: String): VariableLookupResult = { variables.get(name).map(SimpleVariable(_)).getOrElse(NotInScope) } @@ -35,14 +62,14 @@ trait JavaScopeElement { def lookupType(name: String, includeWildcards: Boolean): Option[ScopeType] = { types.get(name) match { - case None if includeWildcards => getNameWithWildcardPrefix(name) - case result => result + case None if includeWildcards && !disableTypeFallback => getNameWithWildcardPrefix(name) + case result => result } } def getNameWithWildcardPrefix(name: String): Option[ScopeType] = { wildcardImports match { - case SingleWildcard(prefix) => Some(ScopeTopLevelType(s"$prefix.$name")) + case SingleWildcard(prefix) => Some(ScopeTopLevelType(s"$prefix.$name", name)) case _ => None } @@ -61,6 +88,15 @@ trait JavaScopeElement { def getVariables(): List[ScopeVariable] = variables.values.toList } +case class PatternVariableInfo( + typePatternExpr: TypePatternExpr, + typeVariableLocal: NewLocal, + typeVariableInitializer: Ast, + localAddedToAst: Boolean = false, + initializerAddedToAst: Boolean = false, + index: Int +) + object JavaScopeElement { sealed trait WildcardImports case object NoWildcard extends WildcardImports @@ -73,35 +109,105 @@ object JavaScopeElement { def getNextAnonymousClassIndex(): Long = anonymousClassKeyPool.next } - class NamespaceScope(val namespace: NewNamespaceBlock) extends JavaScopeElement with TypeDeclContainer { + class NamespaceScope(val namespace: NewNamespaceBlock)(implicit disableTypeFallback: Boolean) + extends JavaScopeElement(disableTypeFallback) + with TypeDeclContainer { val isStatic = false } - class BlockScope extends JavaScopeElement { + class BlockScope(implicit disableTypeFallback: Boolean) extends JavaScopeElement(disableTypeFallback) { val isStatic = false def addLocal(local: NewLocal): Unit = { addVariableToScope(ScopeLocal(local)) } + + def addPatternLocal(local: NewLocal, typePatternExpr: TypePatternExpr): Unit = { + addVariableToScope(ScopePatternVariable(local, typePatternExpr)) + } } - class MethodScope(val method: NewMethod, val returnType: ExpectedType, override val isStatic: Boolean) - extends JavaScopeElement + class MethodScope(val method: NewMethod, val returnType: ExpectedType, override val isStatic: Boolean)(implicit + val withSchemaValidation: ValidationMode, + disableTypeFallback: Boolean + ) extends JavaScopeElement(disableTypeFallback) with AnonymousClassCounter { - def addParameter(parameter: NewMethodParameterIn): Unit = { - addVariableToScope(ScopeParameter(parameter)) + + private val temporaryLocals = mutable.ListBuffer[NewLocal]() + private val patternVariableInfoIdentityMap: mutable.Map[TypePatternExpr, PatternVariableInfo] = + new util.IdentityHashMap[TypePatternExpr, PatternVariableInfo]().asScala + // The insertion order should be preserved to ensure stable results when getting unadded variable asts + private var patternVariableIndex = 0 + + def addParameter(parameter: NewMethodParameterIn, genericSignature: String): Unit = { + addVariableToScope(ScopeParameter(parameter, genericSignature)) + } + + def addTemporaryLocal(local: NewLocal): Unit = { + temporaryLocals.addOne(local) + } + + def getTemporaryLocals: List[NewLocal] = temporaryLocals.toList + + def putPatternVariableInfo( + typePatternExpr: TypePatternExpr, + typeVariableLocal: NewLocal, + typeVariableInitializer: Ast + ): Unit = { + patternVariableInfoIdentityMap.put( + typePatternExpr, + PatternVariableInfo(typePatternExpr, typeVariableLocal, typeVariableInitializer, index = patternVariableIndex) + ) + patternVariableIndex += 1 + } + + def getPatternVariableInfo(typePatternExpr: TypePatternExpr): Option[PatternVariableInfo] = { + patternVariableInfoIdentityMap.get(typePatternExpr) + } + + def registerPatternVariableInitializerToBeAddedToGraph(typePatternExpr: TypePatternExpr): Unit = { + patternVariableInfoIdentityMap.get(typePatternExpr).foreach { patternVariableInfo => + patternVariableInfoIdentityMap + .put(typePatternExpr, patternVariableInfo.copy(initializerAddedToAst = true)) + } + } + + def registerPatternVariableLocalToBeAddedToGraph(typePatternExpr: TypePatternExpr): Unit = { + patternVariableInfoIdentityMap.get(typePatternExpr).foreach { patternVariableInfo => + patternVariableInfoIdentityMap.put(typePatternExpr, patternVariableInfo.copy(localAddedToAst = true)) + } + } + + def getUnaddedPatternVariableAstsAndMarkAdded(): List[Ast] = { + val result = mutable.ListBuffer[Ast]() + patternVariableInfoIdentityMap.values.toArray.sortBy(_.index).foreach { patternInfo => + if (!patternInfo.localAddedToAst) { + result.addOne(Ast(patternInfo.typeVariableLocal)) + registerPatternVariableLocalToBeAddedToGraph(patternInfo.typePatternExpr) + } + + if (!patternInfo.initializerAddedToAst) { + result.addOne(patternInfo.typeVariableInitializer) + registerPatternVariableInitializerToBeAddedToGraph(patternInfo.typePatternExpr) + } + } + result.toList } } - class FieldDeclScope(override val isStatic: Boolean, val name: String) extends JavaScopeElement + class FieldDeclScope(override val isStatic: Boolean, val name: String)(implicit disableTypeFallback: Boolean) + extends JavaScopeElement(disableTypeFallback) class TypeDeclScope( val typeDecl: NewTypeDecl, override val isStatic: Boolean, private[scope] val capturedVariables: Map[String, CapturedVariable], outerClassType: Option[String], - val declaredMethodNames: Set[String] - ) extends JavaScopeElement + outerClassGenericSignature: Option[String], + val declaredMethodNames: Set[String], + val recordParameters: List[Parameter] + )(implicit disableTypeFallback: Boolean) + extends JavaScopeElement(disableTypeFallback) with TypeDeclContainer with AnonymousClassCounter { private val bindingTableEntries = mutable.ListBuffer[BindingTableEntry]() @@ -141,7 +247,9 @@ object JavaScopeElement { def getUsedCaptures(): List[ScopeVariable] = { val outerScope = outerClassType.map(typ => - ScopeLocal(NewLocal().name(NameConstants.OuterClass).typeFullName(typ).code(NameConstants.OuterClass)) + val localNode = NewLocal().name(NameConstants.OuterClass).typeFullName(typ).code(NameConstants.OuterClass) + outerClassGenericSignature.foreach(localNode.genericSignature(_)) + ScopeLocal(localNode) ) val sortedUsedCaptures = usedCaptureParams.toList.sortBy(_.name) diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/Scope.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/Scope.scala index b0025ca254be..04a1c9fce729 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/Scope.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/Scope.scala @@ -1,11 +1,13 @@ package io.joern.javasrc2cpg.scope +import com.github.javaparser.ast.body.Parameter +import com.github.javaparser.ast.expr.TypePatternExpr import io.joern.javasrc2cpg.astcreation.ExpectedType import io.joern.javasrc2cpg.scope.Scope.* import io.joern.javasrc2cpg.scope.JavaScopeElement.* import io.joern.javasrc2cpg.util.MultiBindingTableAdapterForJavaparser.JavaparserBindingDeclType import io.joern.javasrc2cpg.util.NameConstants -import io.joern.x2cpg.Ast +import io.joern.x2cpg.{Ast, ValidationMode} import io.joern.x2cpg.utils.ListUtils.* import io.shiftleft.codepropertygraph.generated.nodes.* import org.slf4j.LoggerFactory @@ -22,7 +24,7 @@ case class NodeTypeInfo( isField: Boolean = false, isStatic: Boolean = false ) -class Scope { +class Scope(implicit val withSchemaValidation: ValidationMode, val disableTypeFallback: Boolean) { private val logger = LoggerFactory.getLogger(this.getClass) private var scopeStack: List[JavaScopeElement] = Nil @@ -39,12 +41,19 @@ class Scope { scopeStack = new FieldDeclScope(isStatic, name) :: scopeStack } - def pushTypeDeclScope(typeDecl: NewTypeDecl, isStatic: Boolean, methodNames: Set[String] = Set.empty): Unit = { + def pushTypeDeclScope( + typeDecl: NewTypeDecl, + isStatic: Boolean, + outerClassGenericSignature: Option[String] = None, + methodNames: Set[String] = Set.empty, + recordParameters: List[Parameter] = Nil + ): Unit = { val captures = getCapturesForNewScope(isStatic) val outerClassType = scopeStack.takeUntil(_.isInstanceOf[TypeDeclScope]) match { case Nil => None - case (head: TypeDeclScope) :: Nil => Option.unless(isStatic)(head.typeDecl.fullName) + case (head: TypeDeclScope) :: Nil => + Option.unless(isStatic)(head.typeDecl.fullName) case head :: Nil => // make exhaustive match checker happy, but is impossible @@ -52,12 +61,20 @@ class Scope { case scopes => Option - .unless(isStatic || scopes.init.exists(_.isStatic)) { - scopes.lastOption.collectFirst { case typeDeclScope: TypeDeclScope => typeDeclScope.typeDecl.fullName } - } + .unless(isStatic || scopes.init.exists(_.isStatic))(scopes.lastOption.collectFirst { + case typeDeclScope: TypeDeclScope => typeDeclScope.typeDecl.fullName + }) .flatten } - scopeStack = new TypeDeclScope(typeDecl, isStatic, captures, outerClassType, methodNames) :: scopeStack + scopeStack = new TypeDeclScope( + typeDecl, + isStatic, + captures, + outerClassType, + outerClassGenericSignature, + methodNames, + recordParameters + ) :: scopeStack } def pushNamespaceScope(namespace: NewNamespaceBlock): Unit = { @@ -74,19 +91,24 @@ class Scope { def popNamespaceScope(): NamespaceScope = popScope[NamespaceScope]() - private def popScope[ScopeType <: JavaScopeElement](): ScopeType = { + private def popScope[ScopeType0 <: JavaScopeElement](): ScopeType0 = { val scope = scopeStack.head scopeStack = scopeStack.tail - scope.asInstanceOf[ScopeType] + scope.asInstanceOf[ScopeType0] } def addTopLevelType(name: String, typeFullName: String): Unit = { - val scopeType = ScopeTopLevelType(typeFullName) + val scopeType = ScopeTopLevelType(typeFullName, name) scopeStack.head.addTypeToScope(name, scopeType) } - def addInnerType(name: String, typeFullName: String): Unit = { - val scopeType = ScopeInnerType(typeFullName) + def addInnerType(name: String, typeFullName: String, internalName: String): Unit = { + val scopeType = ScopeInnerType(typeFullName, internalName) + scopeStack.head.addTypeToScope(name, scopeType) + } + + def addTypeParameter(name: String, typeFullName: String): Unit = { + val scopeType = ScopeTypeParam(typeFullName, name) scopeStack.head.addTypeToScope(name, scopeType) } @@ -279,6 +301,14 @@ class Scope { case _ => None } } + + def addLocalsForPatternsToEnclosingBlock(patterns: List[TypePatternExpr]): Unit = { + patterns.flatMap(enclosingMethod.get.getPatternVariableInfo(_)).foreach { + case PatternVariableInfo(typePatternExpr, variableLocal, _, _, _, _) => + enclosingBlock.get.addPatternLocal(variableLocal, typePatternExpr) + } + } + } object Scope { @@ -295,14 +325,15 @@ object Scope { sealed trait ScopeType { def typeFullName: String + def name: String } /** Used for top-level type declarations and imports that do not have captures to be concerned about or synthetic * names in the cpg */ - final case class ScopeTopLevelType(override val typeFullName: String) extends ScopeType + final case class ScopeTopLevelType(override val typeFullName: String, override val name: String) extends ScopeType - final class ScopeInnerType(override val typeFullName: String) extends ScopeType { + final class ScopeInnerType(override val typeFullName: String, override val name: String) extends ScopeType { private val usedCaptures: mutable.ListBuffer[ScopeVariable] = mutable.ListBuffer() override def equals(other: Any): Boolean = { @@ -316,27 +347,39 @@ object Scope { } object ScopeInnerType { - def apply(typeFullName: String): ScopeInnerType = { - new ScopeInnerType(typeFullName) + def apply(typeFullName: String, name: String): ScopeInnerType = { + new ScopeInnerType(typeFullName, name) } } + final case class ScopeTypeParam(override val typeFullName: String, override val name: String) extends ScopeType + sealed trait ScopeVariable { def node: NewVariableNode def typeFullName: String def name: String + def genericSignature: String } final case class ScopeLocal(override val node: NewLocal) extends ScopeVariable { - val typeFullName: String = node.typeFullName - val name: String = node.name + val typeFullName: String = node.typeFullName + val name: String = node.name + val genericSignature: String = node.genericSignature } - final case class ScopeParameter(override val node: NewMethodParameterIn) extends ScopeVariable { + final case class ScopeParameter(override val node: NewMethodParameterIn, override val genericSignature: String) + extends ScopeVariable { val typeFullName: String = node.typeFullName val name: String = node.name } final case class ScopeMember(override val node: NewMember, isStatic: Boolean) extends ScopeVariable { - val typeFullName: String = node.typeFullName - val name: String = node.name + val typeFullName: String = node.typeFullName + val name: String = node.name + val genericSignature: String = node.genericSignature + } + final case class ScopePatternVariable(override val node: NewLocal, typePatternExpr: TypePatternExpr) + extends ScopeVariable { + val typeFullName: String = node.typeFullName + val name: String = node.name + val genericSignature: String = node.genericSignature } sealed trait VariableLookupResult { diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/JmodClassPath.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/JmodClassPath.scala index e234475e2508..a9891ff5aabf 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/JmodClassPath.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/JmodClassPath.scala @@ -1,10 +1,10 @@ package io.joern.javasrc2cpg.typesolvers import better.files.File -import io.joern.javasrc2cpg.typesolvers.JmodClassPath._ +import io.joern.javasrc2cpg.typesolvers.JmodClassPath.* import javassist.ClassPath -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.util.Try import java.io.InputStream import java.net.URL diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/TypeInfoCalculator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/TypeInfoCalculator.scala index 5c47594f83a3..a37410c6384a 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/TypeInfoCalculator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/TypeInfoCalculator.scala @@ -11,6 +11,7 @@ import com.github.javaparser.resolution.logic.InferenceVariableType import com.github.javaparser.resolution.model.typesystem.{LazyType, NullType} import com.github.javaparser.resolution.types.* import com.github.javaparser.resolution.types.parametrization.ResolvedTypeParametersMap +import com.github.javaparser.symbolsolver.javaparsermodel.declarations.JavaParserRecordDeclaration import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.{TypeConstants, TypeNameConstants} import io.joern.x2cpg.datastructures.Global import org.slf4j.LoggerFactory @@ -259,6 +260,8 @@ object TypeInfoCalculator { val Object: String = "java.lang.Object" val Class: String = "java.lang.Class" val Iterator: String = "java.util.Iterator" + val Enum: String = "java.lang.Enum" + val Record: String = "java.lang.Record" val Void: String = "void" val Any: String = "ANY" } diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/TypeSizeReducer.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/TypeSizeReducer.scala index 11fa1fdd082f..5e1e0e250267 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/TypeSizeReducer.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/TypeSizeReducer.scala @@ -3,7 +3,7 @@ package io.joern.javasrc2cpg.typesolvers import com.github.javaparser.ast.body.TypeDeclaration import com.github.javaparser.ast.stmt.BlockStmt -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* object TypeSizeReducer { def simplifyType(typeDeclaration: TypeDeclaration[?]): Unit = { diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/BindingTable.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/BindingTable.scala index 27d8c2634036..1e4c3f190960 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/BindingTable.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/BindingTable.scala @@ -3,7 +3,7 @@ package io.joern.javasrc2cpg.util import scala.collection.mutable import io.shiftleft.codepropertygraph.generated.nodes.NewTypeDecl import io.joern.x2cpg.utils.NodeBuilders.newBindingNode -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import io.shiftleft.codepropertygraph.generated.EdgeTypes case class BindingTableEntry(name: String, signature: String, implementingMethodFullName: String) diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/Delombok.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/Delombok.scala index 1d1ab11b932a..69b1354d07be 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/Delombok.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/Delombok.scala @@ -1,14 +1,14 @@ package io.joern.javasrc2cpg.util import better.files.File -import io.joern.x2cpg.utils.ExternalCommand import io.joern.javasrc2cpg.util.Delombok.DelombokMode.* +import io.joern.x2cpg.utils.ExternalCommand import org.slf4j.LoggerFactory -import java.nio.file.{Path, Paths} -import scala.collection.mutable -import scala.util.matching.Regex -import scala.util.{Failure, Success, Try} +import java.nio.file.Path +import scala.util.Failure +import scala.util.Success +import scala.util.Try object Delombok { @@ -53,8 +53,17 @@ object Delombok { System.getProperty("java.class.path") } val command = - s"$javaPath -cp $classPathArg lombok.launch.Main delombok ${inputPath.toAbsolutePath.toString} -d ${outputDir.canonicalPath}" - logger.debug(s"Executing delombok with command $command") + Seq( + javaPath, + "-cp", + classPathArg, + "lombok.launch.Main", + "delombok", + inputPath.toAbsolutePath.toString, + "-d", + outputDir.canonicalPath + ) + logger.debug(s"Executing delombok with command ${command.mkString(" ")}") command } @@ -72,6 +81,7 @@ object Delombok { Try(delombokTempDir.createChild(relativeOutputPath, asDirectory = true)).flatMap { packageOutputDir => ExternalCommand .run(delombokToTempDirCommand(inputDir, packageOutputDir, analysisJavaHome), ".") + .toTry .map(_ => delombokTempDir.path.toAbsolutePath.toString) } } diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/TemporaryNameProvider.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/TemporaryNameProvider.scala new file mode 100644 index 000000000000..3335c2bede60 --- /dev/null +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/TemporaryNameProvider.scala @@ -0,0 +1,14 @@ +package io.joern.javasrc2cpg.util + +class TemporaryNameProvider { + + val tmpNamePrefix = "$obj" + private var tmpIndex: Int = 0 + + def next: String = { + val name = s"$tmpNamePrefix$tmpIndex" + tmpIndex += 1 + name + } + +} diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/Util.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/Util.scala index 8bc1e05e920d..9d844758da35 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/Util.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/util/Util.scala @@ -2,15 +2,12 @@ package io.joern.javasrc2cpg.util import com.github.javaparser.resolution.declarations.ResolvedReferenceTypeDeclaration import com.github.javaparser.resolution.types.ResolvedReferenceType -import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants -import io.joern.x2cpg.{Ast, Defines} -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, PropertyNames} -import io.shiftleft.codepropertygraph.generated.nodes.{NewCall, NewFieldIdentifier, NewMember} +import io.joern.x2cpg.Defines import org.slf4j.LoggerFactory import scala.collection.mutable import scala.util.{Failure, Success, Try} -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* object Util { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/JavaSrc2CpgTestContext.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/JavaSrc2CpgTestContext.scala deleted file mode 100644 index afc3a72a0da0..000000000000 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/JavaSrc2CpgTestContext.scala +++ /dev/null @@ -1,61 +0,0 @@ -package io.joern.javasrc2cpg - -import io.shiftleft.codepropertygraph.generated.Cpg -import io.joern.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlowOptions} -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic -import io.joern.x2cpg.X2Cpg.writeCodeToFile -import io.shiftleft.semanticcpg.layers.LayerCreatorContext - -class JavaSrc2CpgTestContext { - private var code: String = "" - private var buildResult = Option.empty[Cpg] - private var _extraFlows = List.empty[FlowSemantic] - - def buildCpg(runDataflow: Boolean, inferenceJarPaths: Set[String]): Cpg = { - if (buildResult.isEmpty) { - val javaSrc2Cpg = JavaSrc2Cpg() - val config = Config(inferenceJarPaths = inferenceJarPaths) - .withInputPath(writeCodeToFile(code, "javasrc2cpgTest", ".java").getAbsolutePath) - .withOutputPath("") - .withCacheJdkTypeSolver(true) - val cpg = javaSrc2Cpg.createCpgWithOverlays(config) - if (runDataflow) { - val context = new LayerCreatorContext(cpg.get) - val options = new OssDataFlowOptions(extraFlows = _extraFlows) - new OssDataFlow(options).run(context) - } - buildResult = Some(cpg.get) - } - buildResult.get - } - - private def withSource(code: String): JavaSrc2CpgTestContext = { - this.code = code - this - } - - private def withExtraFlows(value: List[FlowSemantic] = List.empty): this.type = { - this._extraFlows = value - this - } - -} - -object JavaSrc2CpgTestContext { - def buildCpg(code: String, inferenceJarPaths: Set[String] = Set.empty): Cpg = { - new JavaSrc2CpgTestContext() - .withSource(code) - .buildCpg(runDataflow = false, inferenceJarPaths = inferenceJarPaths) - } - - def buildCpgWithDataflow( - code: String, - inferenceJarPaths: Set[String] = Set.empty, - extraFlows: List[FlowSemantic] = List.empty - ): Cpg = { - new JavaSrc2CpgTestContext() - .withSource(code) - .withExtraFlows(extraFlows) - .buildCpg(runDataflow = true, inferenceJarPaths = inferenceJarPaths) - } -} diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/config/ConfigTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/config/ConfigTests.scala index 7eeb1c82bb4a..84a6036f6da0 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/config/ConfigTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/config/ConfigTests.scala @@ -18,7 +18,9 @@ class ConfigTests extends AnyWordSpec with Matchers with Inside { "--output", "OUTPUT", "--exclude", - "1EXCLUDE_FILE,2EXCLUDE_FILE", + "1EXCLUDE_FILE", + "--exclude", + "2EXCLUDE_FILE", "--exclude-regex", "EXCLUDE_REGEX", // Frontend-specific args diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/io/JavaSrc2CpgHTTPServerTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/io/JavaSrc2CpgHTTPServerTests.scala new file mode 100644 index 000000000000..773fce3a89c7 --- /dev/null +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/io/JavaSrc2CpgHTTPServerTests.scala @@ -0,0 +1,85 @@ +package io.joern.javasrc2cpg.io + +import better.files.File +import io.joern.x2cpg.utils.server.FrontendHTTPClient +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader +import io.shiftleft.semanticcpg.language.* +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.collection.parallel.CollectionConverters.RangeIsParallelizable +import scala.util.Failure +import scala.util.Success + +class JavaSrc2CpgHTTPServerTests extends AnyWordSpec with Matchers with BeforeAndAfterAll { + + private var port: Int = -1 + + private def newProjectUnderTest(index: Option[Int] = None): File = { + val dir = File.newTemporaryDirectory("javasrc2cpgTestsHttpTest") + val file = dir / "Main.java" + file.createIfNotExists(createParents = true) + val indexStr = index.map(_.toString).getOrElse("") + file.writeText(s""" + |class HelloWorld { + | public static void main(String[] args) { + | System.out.println($indexStr); + | } + |} + |""".stripMargin) + file.deleteOnExit() + dir.deleteOnExit() + } + + override def beforeAll(): Unit = { + // Start server + port = io.joern.javasrc2cpg.Main.startup() + } + + override def afterAll(): Unit = { + // Stop server + io.joern.javasrc2cpg.Main.stop() + } + + "Using javasrc2cpg in server mode" should { + "build CPGs correctly (single test)" in { + val cpgOutFile = File.newTemporaryFile("javasrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l should contain("System.out.println()") + } + } + + "build CPGs correctly (multi-threaded test)" in { + (0 until 10).par.foreach { index => + val cpgOutFile = File.newTemporaryFile("javasrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest(Some(index)) + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l should contain(s"System.out.println($index)") + } + } + } + } + +} diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/passes/ConfigFileCreationPassTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/passes/ConfigFileCreationPassTests.scala index 2891f13b3eaa..d48a1a873f6e 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/passes/ConfigFileCreationPassTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/passes/ConfigFileCreationPassTests.scala @@ -1,14 +1,13 @@ package io.joern.javasrc2cpg.passes import better.files.File +import flatgraph.misc.TestUtils.* import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.joern.x2cpg.passes.frontend.JavaConfigFileCreationPass import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.NewMetaData import io.shiftleft.semanticcpg.language.* import io.shiftleft.utils.ProjectRoot -import overflowdb.BatchedUpdate -import overflowdb.BatchedUpdate.DiffGraphBuilder import java.nio.file.Paths @@ -18,8 +17,8 @@ class ConfigFileCreationPassTests extends JavaSrcCode2CpgFixture { ProjectRoot.relativise("joern-cli/frontends/javasrc2cpg/src/test/resources/config_tests") "it should find the correct config files" in { - val cpg = new Cpg() - BatchedUpdate.applyDiff(cpg.graph, Cpg.newDiffGraphBuilder.addNode(NewMetaData().root(testConfigDir)).build()) + val cpg = Cpg.from(_.addNode(NewMetaData().root(testConfigDir))) + val foundFiles = new JavaConfigFileCreationPass(cpg).generateParts().map(_.canonicalPath) val absoluteConfigDir = File(testConfigDir).canonicalPath foundFiles should contain theSameElementsAs Array( diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/AnnotationTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/AnnotationTests.scala index 50193d76e07f..77b5f6352c01 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/AnnotationTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/AnnotationTests.scala @@ -5,6 +5,29 @@ import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class AnnotationTests extends JavaSrcCode2CpgFixture { + "annotations that cannot be resolved from imports" should { + val cpg = code(""" + |package foo; + | + |@interface TestMarker {} + |""".stripMargin) + .moreCode(""" + |package bar; + | + |import foo.*; + |import bar.*; + | + |public class Bar { + | @TestMarker + | public void bar() {} + |} + |""".stripMargin) + + "have the annotation type be resolved" in { + cpg.method.name("bar").annotation.fullName.l shouldBe List("foo.TestMarker") + } + } + "normal value annotations" should { lazy val cpg = code(""" |import some.NormalAnnotation; diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/AnonymousClassTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/AnonymousClassTests.scala index 8e5c820b37ac..edc4fe60bd9d 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/AnonymousClassTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/AnonymousClassTests.scala @@ -3,11 +3,53 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.JavaSrc2Cpg import io.shiftleft.semanticcpg.language.* import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.nodes.{Binding, Block, Call, FieldIdentifier, Identifier, TypeDecl} +import io.shiftleft.codepropertygraph.generated.nodes.{ + Binding, + Block, + Call, + FieldIdentifier, + Identifier, + TypeDecl, + TypeRef +} import io.shiftleft.codepropertygraph.generated.Operators class AnonymousClassTests extends JavaSrcCode2CpgFixture { + "mixed static/non-static anonymous classes with the same name as children of lambdas" should { + val cpg = code(""" + |package foo; + | + |public class Foo { + | + | private static FirstProvider method1() { + | return firstTask -> { + | firstTask.doFirst(new Action() { }); + | }; + | } + | + | private SecondProvider method2() { + | return secondTask -> { + | secondTask.doSecond(new Action() { }); + | }; + | } + |} + | + |""".stripMargin) + + "have the correct names" in { + cpg.typeDecl.name(".*Action.*").fullName.sorted.l shouldBe List( + "foo.Foo.0.Action$0", + "foo.Foo.1.Action$0" + ) + } + + "not result in any orphan locals" in { + !cpg.local.exists(_._astIn.isEmpty) shouldBe true + } + + } + "simple anonymous classes extending interfaces in method bodies" should { val cpg = code(""" |package foo; @@ -178,10 +220,8 @@ class AnonymousClassTests extends JavaSrcCode2CpgFixture { fieldAccess.name shouldBe Operators.fieldAccess fieldAccess.typeFullName shouldBe "foo.Bar" - inside(fieldAccess.argument.l) { case List(fooIdentifier: Identifier, bField: FieldIdentifier) => - fooIdentifier.name shouldBe "Foo" - fooIdentifier.typeFullName shouldBe "foo.Foo" - + inside(fieldAccess.argument.l) { case List(typeRef: TypeRef, bField: FieldIdentifier) => + typeRef.typeFullName shouldBe "foo.Foo" bField.canonicalName shouldBe "b" } } diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ArithmeticOperationsTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ArithmeticOperationsTests.scala index e804e57eee6f..6e33a3755eef 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ArithmeticOperationsTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ArithmeticOperationsTests.scala @@ -5,7 +5,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.Identifier import io.shiftleft.semanticcpg.language.toNodeTypeStarters -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ArithmeticOperationsTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ArrayTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ArrayTests.scala index e1f247b52fd8..1c44df3ef426 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ArrayTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ArrayTests.scala @@ -3,7 +3,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ArrayTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/BindingTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/BindingTests.scala index e7904883cec5..153b99d7e4a1 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/BindingTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/BindingTests.scala @@ -1,6 +1,6 @@ package io.joern.javasrc2cpg.querying -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture class BindingTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/BooleanOperationsTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/BooleanOperationsTests.scala index 7ec4357a2891..76e69e381f6c 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/BooleanOperationsTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/BooleanOperationsTests.scala @@ -3,7 +3,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class BooleanOperationsTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallGraphTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallGraphTests.scala index c54121618cf8..1c680d2d689b 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallGraphTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallGraphTests.scala @@ -2,7 +2,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CallGraphTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallTests.scala index 3072ad5574d3..9e62224b211f 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallTests.scala @@ -6,9 +6,7 @@ import io.shiftleft.codepropertygraph.generated.edges.Ref import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, nodes} import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier, Literal, MethodParameterIn} import io.shiftleft.semanticcpg.language.NoResolve -import io.shiftleft.semanticcpg.language._ -import overflowdb.traversal.jIteratortoTraversal -import overflowdb.traversal.toNodeTraversal +import io.shiftleft.semanticcpg.language.* class NewCallTests extends JavaSrcCode2CpgFixture { "calls to imported methods" when { @@ -284,7 +282,7 @@ class NewCallTests extends JavaSrcCode2CpgFixture { cpg.method.name("test").call.name("foo").argument(0).outE.collectAll[Ref].l match { case List(ref) => - ref.inNode match { + ref.dst match { case param: MethodParameterIn => param.name shouldBe "this" param.index shouldBe 0 @@ -309,7 +307,7 @@ class NewCallTests extends JavaSrcCode2CpgFixture { cpg.method.name("test").call.name("foo").argument(0).outE.collectAll[Ref].l match { case List(ref) => - ref.inNode match { + ref.dst match { case param: MethodParameterIn => param.name shouldBe "this" param.index shouldBe 0 diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CfgTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CfgTests.scala index 5bcf4d6c7fec..9b7e75ec1bd0 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CfgTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CfgTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CfgTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ClassLoaderTypeTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ClassLoaderTypeTests.scala index 1178f9113a77..ceb60c86b9a5 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ClassLoaderTypeTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ClassLoaderTypeTests.scala @@ -2,7 +2,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.Config import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.joern.x2cpg.utils.ExternalCommand class ClassLoaderTypeTests extends JavaSrcCode2CpgFixture { @@ -48,7 +48,6 @@ class ClassLoaderTypeTests extends JavaSrcCode2CpgFixture { } "be resolved by the system classloader (java 17)" in { - println(System.getProperty("java.version")) val cpg = code(testCode) cpg.call.name("getIconHeight").methodFullName.head.startsWith(" val List(_: Local, assign: Call, init: Call) = method.astChildren.isBlock.astChildren.l: @unchecked - assign.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + assign.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() assign.name shouldBe Operators.assignment val alloc = assign.argument(2).asInstanceOf[Call] alloc.name shouldBe ".alloc" alloc.code shouldBe "new Bar(4, 2)" - alloc.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + alloc.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() alloc.methodFullName shouldBe ".alloc" alloc.typeFullName shouldBe "Bar" alloc.argument.size shouldBe 0 init.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName init.methodFullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int,int)" - init.callOut.head.fullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int,int)" - init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + init._methodViaCallOut.head.fullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int,int)" + init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() init.typeFullName shouldBe "void" init.signature shouldBe "void(int,int)" init.code shouldBe "new Bar(4, 2)" @@ -303,20 +303,20 @@ class ConstructorInvocationTests extends JavaSrcCode2CpgFixture { case List(method) => val List(assign: Call, init: Call) = method.astChildren.isBlock.astChildren.l: @unchecked - assign.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + assign.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() assign.name shouldBe Operators.assignment val alloc = assign.argument(2).asInstanceOf[Call] alloc.name shouldBe ".alloc" alloc.code shouldBe "new Bar(4, 2)" - alloc.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + alloc.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() alloc.methodFullName shouldBe ".alloc" alloc.typeFullName shouldBe "Bar" alloc.argument.size shouldBe 0 init.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName init.methodFullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int,int)" - init.callOut.head.fullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int,int)" - init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + init._methodViaCallOut.head.fullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int,int)" + init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() init.typeFullName shouldBe "void" init.signature shouldBe "void(int,int)" init.code shouldBe "new Bar(4, 2)" @@ -362,16 +362,16 @@ class ConstructorInvocationTests extends JavaSrcCode2CpgFixture { alloc.order shouldBe 2 alloc.argumentIndex shouldBe 2 alloc.code shouldBe "new Bar(42)" - alloc.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + alloc.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() alloc.typeFullName shouldBe "Bar" alloc.argument.size shouldBe 0 init.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName init.methodFullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int)" - init.callOut.head.fullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int)" + init._methodViaCallOut.head.fullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int)" init.signature shouldBe "void(int)" init.code shouldBe "new Bar(42)" - init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() init.argument.size shouldBe 2 val List(obj: Identifier, initArg1: Literal) = init.argument.l: @unchecked @@ -411,16 +411,16 @@ class ConstructorInvocationTests extends JavaSrcCode2CpgFixture { alloc.order shouldBe 2 alloc.argumentIndex shouldBe 2 alloc.code shouldBe "new Bar(42)" - alloc.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + alloc.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() alloc.typeFullName shouldBe "Bar" alloc.argument.size shouldBe 0 init.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName init.methodFullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int)" - init.callOut.head.fullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int)" + init._methodViaCallOut.head.fullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int)" init.signature shouldBe "void(int)" init.code shouldBe "new Bar(42)" - init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() init.argument.size shouldBe 2 val List(obj: Identifier, initArg1: Literal) = init.argument.l: @unchecked @@ -447,7 +447,7 @@ class ConstructorInvocationTests extends JavaSrcCode2CpgFixture { val List(init: Call) = method.astChildren.isBlock.astChildren.l: @unchecked init.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName init.methodFullName shouldBe s"Bar.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int)" - init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() init.typeFullName shouldBe "void" init.signature shouldBe "void(int)" @@ -475,7 +475,7 @@ class ConstructorInvocationTests extends JavaSrcCode2CpgFixture { val List(init: Call) = method.astChildren.isBlock.astChildren.l: @unchecked init.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName init.methodFullName shouldBe s"Foo.${io.joern.x2cpg.Defines.ConstructorMethodName}:void(int)" - init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.toString + init.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH.name() init.typeFullName shouldBe "void" init.signature shouldBe "void(int)" diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ControlStructureTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ControlStructureTests.scala index e1fbe59ce842..189933d2673a 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ControlStructureTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ControlStructureTests.scala @@ -11,12 +11,12 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ Identifier, Literal, Local, - Return + Return, + TypeRef } -import io.shiftleft.semanticcpg.language._ -import overflowdb.traversal.toNodeTraversal +import io.shiftleft.semanticcpg.language.* -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* class NewControlStructureTests extends JavaSrcCode2CpgFixture { @@ -156,10 +156,8 @@ class NewControlStructureTests extends JavaSrcCode2CpgFixture { fieldAccess.name shouldBe Operators.fieldAccess fieldAccess.typeFullName shouldBe "java.lang.String[]" - inside(fieldAccess.argument.l) { case List(barIdentifier: Identifier, staticArr: FieldIdentifier) => - barIdentifier.name shouldBe "Bar" - barIdentifier.typeFullName shouldBe "Bar" - + inside(fieldAccess.argument.l) { case List(barTypeRef: TypeRef, staticArr: FieldIdentifier) => + barTypeRef.typeFullName shouldBe "Bar" staticArr.canonicalName shouldBe "STATIC_ARR" } } diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/EnumTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/EnumTests.scala index 84e8ed5e3df6..e4852d019694 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/EnumTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/EnumTests.scala @@ -2,7 +2,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.Literal -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class EnumTests extends JavaSrcCode2CpgFixture { val cpg = code(""" @@ -24,6 +24,10 @@ class EnumTests extends JavaSrcCode2CpgFixture { |} |""".stripMargin) + "the enum type should extends java.lang.Enum" in { + cpg.typeDecl.name("FuzzyBool").inheritsFromTypeFullName.l shouldBe List("java.lang.Enum") + } + "it should parse a basic enum without values" in { inside(cpg.typeDecl.name(".*FuzzyBool.*").l) { case List(typeDecl) => typeDecl.code shouldBe "public enum FuzzyBool" diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FieldAccessTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FieldAccessTests.scala index d1225cd11519..b9fbc84ea31b 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FieldAccessTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FieldAccessTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier, TypeRef} import io.shiftleft.semanticcpg.language.* class FieldAccessTests extends JavaSrcCode2CpgFixture { @@ -82,4 +82,28 @@ class FieldAccessTests extends JavaSrcCode2CpgFixture { access.referencedMember.name.head shouldBe "value" } + "correctly handle access to statically imported field" in { + val cpg = code(""" + |import static Bar.STATIC_INT; + |public class Foo { + | public void foo() { + | int x = STATIC_INT; + | } + |} + |""".stripMargin) + .moreCode( + """ + |public class Bar { + | public static int STATIC_INT = 111; + |} + |""".stripMargin, + fileName = "Bar.java" + ) + + val List(assignment) = cpg.call.code("int x = STATIC_INT").l + val fieldAccess = assignment.argument(2).asInstanceOf[Call] + val typeRef = fieldAccess.argument(1).asInstanceOf[TypeRef] + typeRef.typeFullName shouldBe "Bar" + } + } diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FileTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FileTests.scala index bd64f35330a9..728ffbdb36f1 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FileTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FileTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import org.scalatest.Ignore diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/GenericSignatureTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/GenericSignatureTests.scala new file mode 100644 index 000000000000..52a4ccff727a --- /dev/null +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/GenericSignatureTests.scala @@ -0,0 +1,1084 @@ +package io.joern.javasrc2cpg.querying + +import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture +import io.shiftleft.semanticcpg.language.* + +/** ### Class type signatures + * - In most cases, only the simple name for the class will be used (so `LString;` will be used instead of + * `Ljava/lang/String`) + * + * - Where a qualified name is used in source, that name is used verbatim in the signature, for example + * `Ljava.util.List` (note the `.` were not substituted for `/`. + * + * - For local classes, the name of the class as it appears in the CPG is used in the signature for instances of that + * class (we don't follow the JVM naming scheme for these), for example + * `Ltestpackage.TestClass.testMethod.LocalClass;` + * + * ### Type parameter bounds From the language specification: + * ``` + * TypeParameter: + * Identifier ClassBound {InterfaceBound} + * + * ClassBound: + * : [ReferenceTypeSignature] + * + * InterfaceBound: + * : ReferenceTypeSignature + * ``` + * If a type parameter only has interface bounds I1, I2, ..., then the signature should contain `` (note + * the empty class bound), but in general we won't know if a type is a class or interface without resolving the it, so + * the signature in the CPG will contain `` instead. + * + * ### Unspecified types Where no type name is specified, the special `L__unspecified_type;` type is used in generic + * signatures. This happens in a few places: + * - For lambda return types and lambda parameters which do not have explicit type annotations + * + * - For lambda type decls + * + * - For locals with a `var` type, for example `var x = 42` + * + * - For synthetic locals created for `foreach` loops, for example in `for (String item : items())`, we create a + * temporary `String[] $iterLocal0 = items()` local which will have an unspecified signature (`item` will still + * have the signature `LString;` as expected) + * + * - For synthetic locals created for the LHS of `instanceof` expressions with pattern matching, for example in + * `foo() instanceof String s`, we create an `Object o = foo()` local (since the type depends on the return type of + * `foo`). + */ +class GenericSignatureTests extends JavaSrcCode2CpgFixture { + + "a simple example with primitive types" should { + val cpg = code(""" + |package test; + | + |class Test { + | char charMember; + | + | public void test(boolean b) { + | int x; + | } + |} + |""".stripMargin) + + "have the correct generic signature for locals" in { + cpg.local.genericSignature.l shouldBe List("I") + } + + "have the correct generic signature for a void method with a boolean arg" in { + cpg.method.name("test").genericSignature.l shouldBe List("(Z)V") + } + + "have the correct generic signature for the parameter" in { + cpg.member.name("charMember").genericSignature.l shouldBe List("C") + } + + "have the correct generic signature for the type decl implicitly extending java.lang.Object" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LObject;") + } + } + + "a method with parameters and a non-void return type" should { + val cpg = code(""" + |package test; + | + |class Test { + | public String test(Test t, Integer i) { + | return null; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.method.name("test").genericSignature.l shouldBe List("(LTest;LInteger;)LString;") + } + } + + "a method with an unresolved return type" should { + val cpg = code(""" + |package test; + | + |class Test { + | public Foo test(Test t) { + | return null; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.method.name("test").genericSignature.l shouldBe List("(LTest;)LFoo;") + } + } + + "a method with an unresolved parameter" should { + val cpg = code(""" + |package test; + | + |class Test { + | public void test(Foo f) { + | return null; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.method.name("test").genericSignature.l shouldBe List("(LFoo;)V") + } + } + + "a class extending another class" should { + val cpg = code(""" + |package foo; + | + |class Foo {} + |""".stripMargin).moreCode(""" + |package test; + | + |import foo.Foo; + | + |class Test extends Foo {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LFoo;") + } + } + + "a class implementing an interface" should { + val cpg = code(""" + |package foo; + | + |interface Foo {} + |""".stripMargin).moreCode(""" + |package test; + | + |import foo.Foo; + | + |class Test implements Foo {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LObject;LFoo;") + } + } + + "a class extending another class and implementing an interface" should { + val cpg = code(""" + |package foo; + | + |class Foo {} + |""".stripMargin) + .moreCode(""" + |package bar; + | + |interface Bar {} + |""".stripMargin) + .moreCode(""" + |package test; + | + |import foo.Foo; + |import bar.Bar; + | + |class Test extends Foo implements Bar {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LFoo;LBar;") + } + } + + "a class implementing multiple interfaces" should { + val cpg = code(""" + |package foo; + | + |interface Foo {} + |""".stripMargin) + .moreCode(""" + |package bar; + | + |interface Bar {} + |""".stripMargin) + .moreCode(""" + |package test; + | + |import foo.Foo; + |import bar.Bar; + | + |class Test implements Foo, Bar {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LObject;LFoo;LBar;") + } + } + + "an interface not extending another interface" should { + val cpg = code(""" + |package foo; + | + |interface Foo {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Foo").genericSignature.l shouldBe List("LObject;") + } + } + + "an interface extending another interface" should { + val cpg = code(""" + |package foo; + | + |interface Foo {} + |""".stripMargin).moreCode(""" + |package bar; + | + |import foo.Foo; + | + |interface Bar extends Foo {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Bar").genericSignature.l shouldBe List("LObject;LFoo;") + } + } + + "an interface extending multiple interfaces" should { + val cpg = code(""" + |package foo; + | + |interface Foo {} + |""".stripMargin) + .moreCode(""" + |package bar; + | + |interface Bar {} + |""".stripMargin) + .moreCode(""" + |package test; + | + |import foo.Foo; + |import bar.Bar; + | + |interface Test extends Foo, Bar {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LObject;LFoo;LBar;") + } + } + + "a class extending an unresolved class" should { + val cpg = code(""" + |package test; + | + |class Test extends Foo {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LFoo;") + } + } + + "a class implementing an unresolved interface" should { + val cpg = code(""" + |package test; + | + |class Test implements Foo {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LObject;LFoo;") + } + } + + "a resolved lambda method" should { + val cpg = code(""" + |package test; + | + |import java.util.function.Consumer; + | + |class Test { + | public Consumer test() { + | return s -> System.out.println(s); + | } + |} + |""".stripMargin) + + "have the correct generic signature for the lambda method" in { + cpg.method.name(".*lambda.*").genericSignature.l shouldBe List("(L__unspecified_type;)L__unspecified_type;") + } + + "have an empty generic signature for the lambda type decl" in { + cpg.typeDecl.name(".*lambda.*").genericSignature.l shouldBe List("L__unspecified_type;") + } + } + + "a lambda method with an explicit type annotation" should { + val cpg = code(""" + |package test; + | + |import java.util.function.Consumer; + | + |class Test { + | public Consumer test() { + | return (String s) -> System.out.println(s); + | } + |} + |""".stripMargin) + + "have the correct generic signature for the lambda method" in { + cpg.method.name(".*lambda.*").genericSignature.l shouldBe List("(LString;)L__unspecified_type;") + } + + "have an empty generic signature for the lambda type decl" in { + cpg.typeDecl.name(".*lambda.*").genericSignature.l shouldBe List("L__unspecified_type;") + } + } + + "an unresolved lambda method" should { + val cpg = code(""" + |package test; + | + |class Test { + | public Consumer test() { + | return s -> System.out.println(s); + | } + |} + |""".stripMargin) + + "have the correct generic signature for the lambda method" in { + cpg.method.name(".*lambda.*").genericSignature.l shouldBe List("(L__unspecified_type;)L__unspecified_type;") + } + + "have an empty generic signature for the lambda type decl" in { + cpg.typeDecl.name(".*lambda.*").genericSignature.l shouldBe List("L__unspecified_type;") + } + } + + "a nested class" should { + val cpg = code(""" + |package test; + | + |class Test { + | class Nested {} + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.nameExact("Test$Nested").genericSignature.l shouldBe List("LObject;") + } + } + + "a local class" should { + val cpg = code(""" + |package test; + | + |class Test { + | public void test() { + | class Local {} + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Local").genericSignature.l shouldBe List("LObject;") + } + } + + "an anonymous class extending a resolved class" should { + val cpg = code(""" + |package foo; + | + |class Foo {} + |""".stripMargin).moreCode(""" + |package test; + | + |import foo.Foo; + | + |class Test { + | public void test() { + | Foo f = new Foo() {}; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.nameExact("Foo$0").genericSignature.l shouldBe List("LFoo;") + } + } + + "an anonymous class extending an unresolved class" should { + val cpg = code(""" + |package test; + | + |class Test { + | public void test() { + | Foo f = new Foo() {}; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.nameExact("Foo$0").genericSignature.l shouldBe List("LFoo;") + } + } + + "an anonymous class extending an unresolved class which can be resolved from imports" should { + val cpg = code(""" + |package test; + | + |import foo.Foo; + | + |class Test { + | public void test() { + | Foo f = new Foo() {}; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.nameExact("Foo$0").genericSignature.l shouldBe List("LFoo;") + } + } + + "a local with an array type" should { + val cpg = code(""" + |package test; + | + |class Test { + | public void test() { + | String[] items; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("items").genericSignature.l shouldBe List("[LString;") + } + } + + "a generic local with a single type argument" should { + val cpg = code(""" + |package test; + | + |import java.util.List; + | + |class Test { + | public void test() { + | List list; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("list").genericSignature.l shouldBe List("LList;") + } + } + + "a generic local with a single wildcard type argument" should { + val cpg = code(""" + |package test; + | + |import java.util.List; + | + |class Test { + | public void test() { + | List list; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("list").genericSignature.l shouldBe List("LList<*>;") + } + } + + "a generic local with a single wildcard type argument with an upper bound" should { + val cpg = code(""" + |package test; + | + |import java.util.List; + | + |class Test { + | public void test() { + | List list; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("list").genericSignature.l shouldBe List("LList<+LString;>;") + } + } + + "a generic local with a single wildcard type argument with a lower bound" should { + val cpg = code(""" + |package test; + | + |import java.util.List; + | + |class Test { + | public void test() { + | List list; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("list").genericSignature.l shouldBe List("LList<-LString;>;") + } + } + + "a generic local with multiple type arguments" should { + val cpg = code(""" + |package test; + | + |import java.util.Map; + | + |class Test { + | public void test() { + | Map map; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("map").genericSignature.l shouldBe List("LMap;") + } + } + + "a generic local with nested type arguments" should { + val cpg = code(""" + |package test; + | + |import java.util.List; + |import java.util.Map; + | + |class Test { + | public void test() { + | Map> map; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("map").genericSignature.l shouldBe List("LMap;>;") + } + } + + "a generic local with a type variable type from the method signature" should { + val cpg = code(""" + |package test; + | + |class Test { + | public void test() { + | T t; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("t").genericSignature.l shouldBe List("TT;") + } + } + + "a generic local with a nested type variable type from the method signature" should { + val cpg = code(""" + |package test; + | + |import java.util.List; + | + |class Test { + | public void test() { + | List list; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("list").genericSignature.l shouldBe List("LList;") + } + } + + "a generic local with a type variable type from the class signature" should { + val cpg = code(""" + |import java.util.List; + | + |public class Main { + | public void main(String[] args) { + | T t; + | } + |} + | + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("t").genericSignature.l shouldBe List("TT;") + } + } + + "a generic local with a type variable type as a bound from the class signature" should { + val cpg = code(""" + |import java.util.List; + | + |public class Main { + | public void main(String[] args) { + | List t; + | } + |} + | + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("t").genericSignature.l shouldBe List("LList<+TT;>;") + } + } + + "a method with a generic return type and generic parameters" should { + val cpg = code(""" + |package test; + | + |import java.util.List; + | + |class Test { + | public S test(T t) {} + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.method.name("test").genericSignature.l shouldBe List("(TT;)TS;") + } + } + + "a type parameter with multiple interface bounds" should { + val cpg = code(""" + |package test; + | + |interface I1 {} + |interface I2 {} + | + |class Test { + | public void test(T t) {} + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.method.name("test").genericSignature.l shouldBe List("(TT;)V") + } + } + + "a generic member" should { + val cpg = code(""" + |package test; + | + |import java.util.List; + | + |class Test { + | public List list; + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.member.name("list").genericSignature.l shouldBe List("LList;") + } + } + + "an enum typeDecl" should { + val cpg = code(""" + |package test; + | + |enum Test { + | TEST + |} + |""".stripMargin) + + "have the correct generic signature for the type decl" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LEnum;") + } + + "have the correct generic signature for the enum constant" in { + cpg.member.name("TEST").genericSignature.l shouldBe List("LTest;") + } + } + + "a record typeDecl" should { + val cpg = code(""" + |package test; + | + |import java.util.List; + | + |record Test(String value, List list) {} + |""".stripMargin) + + "have the correct generic signature for the record" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LRecord;") + } + + "have the correct generic signature for the record parameter fields" in { + cpg.member.name("value").genericSignature.l shouldBe List("LString;") + cpg.member.name("list").genericSignature.l shouldBe List("LList;") + } + + "have the correct generic signature for the default constructor" in { + cpg.method.nameExact("").genericSignature.l shouldBe List("(LString;LList;)V") + } + + "have the correct generic signature for the record paramater accessors" in { + cpg.method.name("value").genericSignature.l shouldBe List("()LString;") + cpg.method.name("list").genericSignature.l shouldBe List("()LList;") + } + } + + "a type decl extending a generic type" should { + val cpg = code(""" + |package bar; + | + |class Bar {} + |""".stripMargin).moreCode(""" + |package test; + | + |class Test extends Bar {} + |""".stripMargin) + + "have the correct generic signature for the generic class" in { + cpg.typeDecl.name("Bar").genericSignature.l shouldBe List("LObject;") + } + + "have the correct generic signature for the inheriting class" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LBar;") + } + } + + "the lowering for a native foreach loop with a synthetic iterator local" should { + val cpg = code(""" + |package test; + | + |class Test { + | String[] items() { return null; } + | void test() { + | for (String item : items()) {} + | } + |} + |""".stripMargin) + + "have the correct generic signature for the synthetic iterator local" in { + cpg.local.nameExact("$iterLocal0").genericSignature.l shouldBe List("L__unspecified_type;") + } + + "have the correct generic signature for the synthetic index local" in { + cpg.local.nameExact("$idx0").genericSignature.l shouldBe List("I") + } + + "have the correct generic signature for the variable local" in { + cpg.local.name("item").genericSignature.l shouldBe List("LString;") + } + } + + "the lowering for a native foreach loop" should { + val cpg = code(""" + |package test; + | + |class Test { + | void test(String[] items) { + | for (String item : items) {} + | } + |} + |""".stripMargin) + + "have the correct generic signature for the synthetic index local" in { + cpg.local.nameExact("$idx0").genericSignature.l shouldBe List("I") + } + + "have the correct generic signature for the variable local" in { + cpg.local.name("item").genericSignature.l shouldBe List("LString;") + } + } + + "the lowering for an iterator foreach loop" should { + val cpg = code(""" + |package test; + | + |import java.util.List; + | + |class Test { + | void test(List items) { + | for (String item : items) {} + | } + |} + |""".stripMargin) + + "have the correct generic signature for the synthetic iterator local" in { + cpg.local.nameExact("$iterLocal0").genericSignature.l shouldBe List("Ljava.util.Iterator;") + } + + "have the correct generic signature for the variable local" in { + cpg.local.name("item").genericSignature.l shouldBe List("LString;") + } + } + + "the synthetic tmp local in the block representation of a constructor invocation" should { + val cpg = code(""" + |package test; + | + |class Test { + | void test() { + | System.out.println(new String()); + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name.foreach(println) + cpg.local.nameExact("$obj0").genericSignature.l shouldBe List("LString;") + } + } + + "a captured local in a lambda" should { + val cpg = code(""" + |package test; + | + |import java.util.function.Consumer; + | + |class Test { + | public Consumer test(Integer captured) { + | return s -> System.out.println(captured); + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("captured").genericSignature.l shouldBe List("LInteger;") + } + } + + "a pattern initializer requiring a tmp local" should { + val cpg = code(""" + |package test; + | + |class Test { + | public Object foo() { return null; } + | + | public void test() { + | if (foo() instanceof String s) {} + | } + |} + |""".stripMargin) + + "have the correct generic signature for the tmp local" in { + cpg.local.nameExact("$obj0").genericSignature.l shouldBe List("L__unspecified_type;") + } + + "have the correct generic signature for the pattern variable local" in { + cpg.local.name("s").genericSignature.l shouldBe List("LString;") + } + } + + "a local class with captures" should { + val cpg = code(""" + |class Test { + | String mainField; + | + | public void test(Integer testParam) { + | class Foo { + | void foo() { + | System.out.println(mainField + testParam); + | } + | } + | } + |} + |""".stripMargin) + + "have the correct generic signature for the outerClass member" in { + // TODO: This should maybe be `LTest;` instead, but the type variable has no + // meaning in `Foo`. + cpg.member.name("outerClass").genericSignature.l shouldBe List("LTest;") + } + + "have the correct generic signature for a captured parameter member" in { + cpg.member.name("testParam").genericSignature.l shouldBe List("LInteger;") + } + } + + "a class extending a nested class" should { + val cpg = code(""" + |package test; + | + |class Test { + | class Foo {} + | class Bar extends Foo {} + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.nameExact("Test$Bar").genericSignature.l shouldBe List("LTest$Foo;") + } + } + + "a class extending a local class" should { + val cpg = code(""" + |class Test { + | public void test() { + | class Foo {} + | class Bar extends Foo {} + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.nameExact("Bar").genericSignature.l shouldBe List("LTest.test.Foo;") + } + } + + "a default constructor" should { + val cpg = code(""" + |class Test {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.method.nameExact("").genericSignature.l shouldBe List("()V") + } + } + + "an explicit constructor" should { + val cpg = code(""" + |class Test { + | public Test(String s) {} + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.method.nameExact("").genericSignature.l shouldBe List("(LString;)V") + } + } + + "a compact constructor for a record" should { + val cpg = code(""" + |record Test(String s) { + | public Test {} + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.method.nameExact("").genericSignature.l shouldBe List("(LString;)V") + } + } + + "a local with an unresolved fully qualified name" should { + val cpg = code(""" + |class Test { + | public void test() { + | foo.Foo f; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("f").genericSignature.l shouldBe List("Lfoo.Foo;") + } + } + + "a local with an unresolved type which can be inferred from imports" should { + val cpg = code(""" + |import foo.Foo; + | + |class Test { + | public void test() { + | Foo f; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("f").genericSignature.l shouldBe List("LFoo;") + } + } + + "a member with an unresolved fully qualified name" should { + val cpg = code(""" + |class Test { + | foo.Foo f; + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.member.name("f").genericSignature.l shouldBe List("Lfoo.Foo;") + } + } + + "a member with an unresolved type which can be inferred from imports" should { + val cpg = code(""" + |import foo.Foo; + | + |class Test { + | Foo f; + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.member.name("f").genericSignature.l shouldBe List("LFoo;") + } + } + + "a method with an unresolved fully qualified return type and param" should { + val cpg = code(""" + |class Test { + | public foo.Foo test(bar.Bar b) {} + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.method.name("test").genericSignature.l shouldBe List("(Lbar.Bar;)Lfoo.Foo;") + } + } + + "a method with an unresolved return type and param which can be inferred from imports" should { + val cpg = code(""" + |import foo.Foo; + |import bar.Bar; + | + |class Test { + | public Foo test(Bar b) {} + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.method.name("test").genericSignature.l shouldBe List("(LBar;)LFoo;") + } + } + + "a type decl extending an unresolved fully qualified type" should { + val cpg = code(""" + |class Test extends foo.Foo {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("Lfoo.Foo;") + } + } + + "a type decl extending an unresolved type which can be inferred from imports" should { + val cpg = code(""" + |import foo.Foo; + |import bar.Bar; + | + |class Test extends Foo {} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.typeDecl.name("Test").genericSignature.l shouldBe List("LFoo;") + } + } + + "a local with a var type" should { + val cpg = code(""" + |public class Test { + | public void foo() { + | var s = "hello"; + | } + |} + |""".stripMargin) + + "have the correct generic signature" in { + cpg.local.name("s").genericSignature.l shouldBe List("L__unspecified_type;") + } + } +} diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/GenericsTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/GenericsTests.scala index 609e8789244e..8620f7318b31 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/GenericsTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/GenericsTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class GenericsTests extends JavaSrcCode2CpgFixture { "unresolved generic type declarations" should { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ImportTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ImportTests.scala index 8e01a0910bd9..aa26b95d1eab 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ImportTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ImportTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ImportTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/InferenceJarTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/InferenceJarTests.scala index 1d6b65f6dc97..8e77f6af02fe 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/InferenceJarTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/InferenceJarTests.scala @@ -1,16 +1,14 @@ package io.joern.javasrc2cpg.querying -import io.joern.javasrc2cpg.JavaSrc2CpgTestContext -import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants +import io.joern.javasrc2cpg.JavaSrc2Cpg.DefaultConfig +import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.joern.x2cpg.Defines -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.utils.ProjectRoot -import org.scalatest.freespec.AnyFreeSpec -import org.scalatest.matchers.should.Matchers -class InferenceJarTests extends AnyFreeSpec with Matchers { +class InferenceJarTests extends JavaSrcCode2CpgFixture { - private val code: String = + private val _code: String = """ |class Test { | public void test1() { @@ -19,18 +17,18 @@ class InferenceJarTests extends AnyFreeSpec with Matchers { |} |""".stripMargin - "CPG for code where inference jar for dependencies is provided" - { + "CPG for code where inference jar for dependencies is provided" should { val inferenceJarPath = ProjectRoot.relativise("joern-cli/frontends/javasrc2cpg/src/test/resources/Deps.jar") - lazy val cpg = JavaSrc2CpgTestContext.buildCpg(code, inferenceJarPaths = Set(inferenceJarPath)) + lazy val cpg = code(_code).withConfig(DefaultConfig.withInferenceJarPaths(Set(inferenceJarPath))) - "it should resolve the type for Deps" in { + "resolve the type for Deps" in { val call = cpg.method.name("test1").call.name("foo").head call.methodFullName shouldBe "Deps.foo:int()" call.typeFullName shouldBe "int" call.signature shouldBe "int()" } - "it should create stubs for elements used in Deps" in { + "create stubs for elements used in Deps" in { cpg.typeDecl.name("Deps").size shouldBe 1 val depsTypeDecl = cpg.typeDecl.name("Deps").head depsTypeDecl.fullName shouldBe "Deps" @@ -43,10 +41,10 @@ class InferenceJarTests extends AnyFreeSpec with Matchers { } } - "CPG for code where inference jar for dependencies is not provided" - { - lazy val cpg = JavaSrc2CpgTestContext.buildCpg(code) + "CPG for code where inference jar for dependencies is not provided" should { + lazy val cpg = code(_code) - "it should fail to resolve the type for Deps" in { + "fail to resolve the type for Deps" in { val call = cpg.method.name("test1").call.name("foo").head call.methodFullName shouldBe s"${Defines.UnresolvedNamespace}.foo:${Defines.UnresolvedSignature}(0)" call.signature shouldBe s"${Defines.UnresolvedSignature}(0)" diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LiteralTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LiteralTests.scala index ebb4c50a1907..4b95846be327 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LiteralTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LiteralTests.scala @@ -3,7 +3,7 @@ package io.joern.javasrc2cpg.querying import com.github.javaparser.ast.expr.LiteralExpr import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LiteralTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalClassTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalClassTests.scala index 086f70e3fd8d..e44ab3d70906 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalClassTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalClassTests.scala @@ -744,6 +744,17 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { @inline def constructors = cpg.typeDecl.fullName("foo.Foo.enclosingMethod.Local").method.nameExact("").sortBy(_.parameter.size) + "not have any orphan locals or parameters" in { + cpg.local.filter(_.astIn.isEmpty).l shouldBe List() + cpg.parameter.filter(_.astIn.isEmpty).l shouldBe List() + } + + "have ref edges from the outer class identifier to the parameter" in { + inside(cpg.method.nameExact("").filter(_.parameter.name.contains("ctxParam")).l) { case List(constructor) => + constructor.ast.isIdentifier.name("outerClass").refsTo.l shouldBe constructor.parameter.name("outerClass").l + } + } + "have params for captured members for both constructors" in { constructors.head.parameter.name.l shouldBe List("this", "outerClass", "outerParam") constructors.last.parameter.name.l shouldBe List("this", "ctxParam", "outerClass", "outerParam") @@ -869,7 +880,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void(int)" call.signature shouldBe "void(int)" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -940,7 +951,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.foo.Local.:void()" call.signature shouldBe "void()" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -1001,7 +1012,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void()" call.signature shouldBe "void()" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -1055,7 +1066,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void()" call.signature shouldBe "void()" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -1108,7 +1119,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void(int)" call.signature shouldBe "void(int)" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -1168,7 +1179,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void()" call.signature shouldBe "void()" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -1217,7 +1228,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void(int)" call.signature shouldBe "void(int)" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalTests.scala index 6c2024b75004..7ba4a1df0eae 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalTests.scala @@ -2,7 +2,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.Local -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LocalTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LombokTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LombokTests.scala index c8697f6fdef9..ca15ffbfe948 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LombokTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LombokTests.scala @@ -3,7 +3,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants import io.joern.x2cpg.Defines -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.joern.javasrc2cpg.Config class LombokTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MemberTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MemberTests.scala index 7ee3291c7607..8e707b219888 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MemberTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MemberTests.scala @@ -3,7 +3,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.{DispatchTypes, ModifierTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier, Literal, Member} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class NewMemberTests extends JavaSrcCode2CpgFixture { "locals shadowing members" should { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MetaDataTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MetaDataTests.scala index ea2f134d41fc..2a307d8ef928 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MetaDataTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MetaDataTests.scala @@ -2,7 +2,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Languages -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MetaDataTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodParameterTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodParameterTests.scala index b0149ece816c..f0ce9a6bd5d4 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodParameterTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodParameterTests.scala @@ -2,7 +2,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.EvaluationStrategies -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MethodParameterTests2 extends JavaSrcCode2CpgFixture { "non generic method" should { @@ -19,7 +19,7 @@ class MethodParameterTests2 extends JavaSrcCode2CpgFixture { param.order shouldBe 0 param.index shouldBe 0 param.lineNumber shouldBe Some(3) - param.columnNumber shouldBe None + param.columnNumber shouldBe Some(3) param.typeFullName shouldBe "Foo" param.evaluationStrategy shouldBe EvaluationStrategies.BY_SHARING } diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodReturnTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodReturnTests.scala index 67f350880208..ba6c505cc449 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodReturnTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodReturnTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.Ignore class MethodReturnTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodTests.scala index 093544e797b5..03213b5d2467 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/MethodTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.io.File diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/NamespaceBlockTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/NamespaceBlockTests.scala index 77fca2f6ef15..b5d45219eb0d 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/NamespaceBlockTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/NamespaceBlockTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class NamespaceBlockTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/PatternExprTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/PatternExprTests.scala new file mode 100644 index 000000000000..c20565e12e7a --- /dev/null +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/PatternExprTests.scala @@ -0,0 +1,2807 @@ +package io.joern.javasrc2cpg.querying + +import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture +import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Operators} +import io.shiftleft.codepropertygraph.generated.nodes.{ + Block, + Call, + ControlStructure, + FieldIdentifier, + Identifier, + JumpTarget, + Literal, + Local, + TypeRef +} +import io.shiftleft.semanticcpg.language.* + +class PatternExprTests extends JavaSrcCode2CpgFixture { + + "a pattern initializer in a lambda method" should { + val cpg = code(""" + |import java.util.function.Function; + | + |class Foo { + | Function test() { + | return o -> foo() instanceof String s ? s : null; + | } + |} + |""".stripMargin) + + "not create any orphan locals" in { + cpg.local.exists(_._astIn.isEmpty) shouldBe false + } + } + + "a type pattern in an expression in an explicit constructor" should { + val cpg = code(""" + |class Test { + | Object foo() { + | return "abc"; + | } + | + | public Test() { + | boolean b = foo() instanceof String s; + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("foo").nonEmpty shouldBe true + } + + "be represented correctly" in { + inside(cpg.method.name(".*init.*").body.astChildren.l) { + case List(tmpLocal: Local, sLocal: Local, sAssign: Call, bLocal: Local, bAssign: Call) => + tmpLocal.name shouldBe "$obj0" + + bLocal.name shouldBe "b" + + // TODO should s assignment be added if it is never used + // TODO sAssign.code shouldBe "s = (String) $obj0" + + sLocal.name shouldBe "s" + + // TODO bAssign code + bAssign.methodFullName shouldBe Operators.assignment + inside(bAssign.argument.l) { case List(bIdentifier: Identifier, instanceOfCall: Call) => + bIdentifier.name shouldBe "b" + bIdentifier.typeFullName shouldBe "boolean" + bIdentifier.refsTo.l shouldBe List(bLocal) + + instanceOfCall.methodFullName shouldBe Operators.instanceOf + + inside(instanceOfCall.argument.l) { case List(tmpAssign: Call, stringType: TypeRef) => + tmpAssign.methodFullName shouldBe Operators.assignment + // TODO tmpAssign code + + inside(tmpAssign.argument.l) { case List(tmpIdentifier: Identifier, fooCall: Call) => + tmpIdentifier.name shouldBe "$obj0" + tmpIdentifier.typeFullName shouldBe "java.lang.Object" + tmpIdentifier.refsTo.l shouldBe List(tmpLocal) + + fooCall.name shouldBe "foo" + fooCall.methodFullName shouldBe "Test.foo:java.lang.Object()" + } + } + } + } + } + } + + "a pattern matching instanceof in a field initializer" should { + val cpg = code(""" + |import foo.Foo; + | + |class Test { + | public int x = Foo.FOO instanceof String s ? s.length() : -1; + |} + |""".stripMargin) + .moreCode(""" + |package foo; + | + |public class Foo { + | public Object FOO = "abc"; + |} + |""".stripMargin) + "parse" in { + cpg.call.name("length").nonEmpty shouldBe true + } + + "add the local and initialiser for the pattern variable to the method" in { + inside(cpg.typeDecl.name("Test").method.nameExact("").body.astChildren.l) { + case List(sLocal: Local, xAssign: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + + xAssign.methodFullName shouldBe Operators.assignment + + inside(xAssign.argument.l) { case List(xFieldAccess: Call, conditionalExpr: Call) => + xFieldAccess.methodFullName shouldBe Operators.fieldAccess + // TODO xFieldAccess test + + conditionalExpr.methodFullName shouldBe Operators.conditional + conditionalExpr.typeFullName shouldBe "int" + + inside(conditionalExpr.argument.l) { case List(instanceOfCall: Call, lengthCall: Call, minusCall: Call) => + instanceOfCall.methodFullName shouldBe Operators.instanceOf + instanceOfCall.code shouldBe "Foo.FOO instanceof String" + + inside(instanceOfCall.argument.l) { case List(fooFieldAccess: Call, stringType: TypeRef) => + fooFieldAccess.code shouldBe "Foo.FOO" + + // TODO: Fix static field access arguments + // inside(fooFieldAccess.argument.l) { + // case List(fooType: TypeRef, fooFieldName: FieldIdentifier) => + // fooType.typeFullName shouldBe "foo.Foo" + + // fooFieldName.canonicalName shouldBe "FOO" + // } + + stringType.typeFullName shouldBe "java.lang.String" + } + + lengthCall.methodFullName shouldBe "java.lang.String.length:int()" + inside(lengthCall.argument.l) { case List(sAssign: Call) => + sAssign.name shouldBe Operators.assignment + + inside(sAssign.argument.l) { case List(sIdentifier: Identifier, oCast: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.refsTo.l shouldBe List(sLocal) + + // TODO oCast.code shouldBe "(String) $obj0" + } + } + + minusCall.methodFullName shouldBe Operators.minus + } + } + } + } + } + + "a pattern matching instanceof in a static field initializer" should { + val cpg = code(""" + |import foo.Foo; + | + |class Test { + | public static int x = Foo.FOO instanceof String s ? s.length() : -1; + |} + |""".stripMargin) + .moreCode(""" + |package foo; + | + |public class Foo { + | public Object FOO = "abc"; + |} + |""".stripMargin) + "parse" in { + cpg.call.name("length").nonEmpty shouldBe true + } + + "add the local and initialiser for the pattern variable to the method" in { + inside(cpg.typeDecl.name("Test").method.nameExact("").body.astChildren.l) { + case List(sLocal: Local, xAssign: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + + xAssign.methodFullName shouldBe Operators.assignment + // TODO xAssign code + + inside(xAssign.argument.l) { case List(xFieldAccess: Call, conditionalExpr: Call) => + xFieldAccess.methodFullName shouldBe Operators.fieldAccess + // TODO xFieldAccess test + + conditionalExpr.methodFullName shouldBe Operators.conditional + conditionalExpr.typeFullName shouldBe "int" + + inside(conditionalExpr.argument.l) { case List(instanceOfCall: Call, lengthCall: Call, minusCall: Call) => + instanceOfCall.methodFullName shouldBe Operators.instanceOf + inside(instanceOfCall.argument.l) { case List(fooFieldAccess: Call, stringType: TypeRef) => + fooFieldAccess.code shouldBe "Foo.FOO" + + stringType.typeFullName shouldBe "java.lang.String" + } + + lengthCall.methodFullName shouldBe "java.lang.String.length:int()" + inside(lengthCall.argument.l) { case List(sAssign: Call) => + sAssign.name shouldBe Operators.assignment + + inside(sAssign.argument.l) { case List(sIdentifier: Identifier, oCast: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.refsTo.l shouldBe List(sLocal) + + // TODO oCast.code shouldBe "(String) $obj0" + } + } + + minusCall.methodFullName shouldBe Operators.minus + } + } + } + } + } + + "a pattern matching instanceof with a call lhs" should { + val cpg = code(""" + |class Test { + | static String foo() { + | return "Hello, world!"; + | } + | + | static void sink(String s) { /* Do nothing */ } + | + | void test(Object o) { + | if (foo() instanceof String s && s.isEmpty()) { + | sink(s); + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").size shouldBe 1 + } + + "add a tmp local for the foo call to the start of the method" in { + inside(cpg.method.name("test").body.astChildren.l) { case (tmpLocal: Local) :: _ => + tmpLocal.name shouldBe "$obj0" + tmpLocal.code shouldBe "$obj0" + tmpLocal.typeFullName shouldBe "java.lang.String" + } + } + + "create an assignment for the temporary local as the first instanceof argument" in { + inside(cpg.call.nameExact(Operators.instanceOf).argument.head) { case assignment: Call => + assignment.name shouldBe Operators.assignment + assignment.typeFullName shouldBe "java.lang.String" + assignment.code shouldBe "$obj0 = foo()" + + inside(assignment.argument.l) { case List(tmpIdentifier: Identifier, fooCall: Call) => + tmpIdentifier.name shouldBe "$obj0" + tmpIdentifier.code shouldBe "$obj0" + tmpIdentifier.typeFullName shouldBe "java.lang.String" + tmpIdentifier.refsTo.l shouldBe cpg.local.nameExact("$obj0").l + + fooCall.name shouldBe "foo" + fooCall.methodFullName shouldBe "Test.foo:java.lang.String()" + fooCall.typeFullName shouldBe "java.lang.String" + fooCall.code shouldBe "foo()" + } + } + } + } + + "patterns in binary expressions" when { + "a variable is introduced to the RHS of an && expression" should { + + val cpg = code(""" + |class Test { + | void test(Object o) { + | if (o instanceof String s && s.isEmpty()) { + | System.out.println(s); + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("isEmpty").nonEmpty shouldBe true + } + + "not have any nodes with multiple AST parents" in { + cpg.astNode.filter(_._astIn.size > 1).l shouldBe Nil + } + + "be represented correctly" in { + inside(cpg.method.name("test").body.astChildren.l) { case List(sLocal: Local, ifStmt: ControlStructure) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + sLocal.code shouldBe "String s" + + inside(ifStmt.condition.l) { case List(andCall: Call) => + andCall.name shouldBe Operators.logicalAnd + andCall.typeFullName shouldBe TypeConstants.Boolean + // TODO fix code + // andCall.code shouldBe "o instanceof String s && (s = (String) o).isEmpty()" + + inside(andCall.argument.l) { case List(instanceOfCall: Call, isEmptyCall: Call) => + instanceOfCall.name shouldBe Operators.instanceOf + instanceOfCall.code shouldBe "o instanceof String" + + inside(instanceOfCall.argument.l) { case List(oIdentifier: Identifier, stringType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + + stringType.typeFullName shouldBe "java.lang.String" + } + + isEmptyCall.name shouldBe "isEmpty" + isEmptyCall.methodFullName shouldBe "java.lang.String.isEmpty:boolean()" + // TODO Fix code + // isEmptyCall.code shouldBe "(s = (String) o).isEmpty()" + + inside(isEmptyCall.argument.l) { case List(sAssignment: Call) => + sAssignment.name shouldBe Operators.assignment + // TODO Fix code + // sAssignment.code shouldBe "s = (String) o" + sAssignment.typeFullName shouldBe "java.lang.String" + + inside(sAssignment.argument.l) { case List(sIdentifier: Identifier, castCall: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.method.name("test").local.name("s").l + + castCall.name shouldBe Operators.cast + castCall.methodFullName shouldBe Operators.cast + castCall.code shouldBe "(String) o" + + inside(castCall.argument.l) { case List(innerStringType: TypeRef, innerOIdentifier: Identifier) => + innerStringType.typeFullName shouldBe "java.lang.String" + + innerOIdentifier.name shouldBe "o" + innerOIdentifier.typeFullName shouldBe "java.lang.Object" + innerOIdentifier.refsTo.l shouldBe cpg.method.name("test").parameter.name("o").l + } + } + } + } + } + } + } + } + + "a variable is introduced to the RHS of an || expression" should { + val cpg = code(""" + |class Test { + | void test(Object o) { + | if (!(o instanceof String s) || s.isEmpty()) { + | System.out.println("no input found"); + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("isEmpty").nonEmpty shouldBe true + } + + "be represented correctly" in { + inside(cpg.method.name("test").body.astChildren.l) { case List(sLocal: Local, ifStmt: ControlStructure) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + sLocal.code shouldBe "String s" + + inside(ifStmt.condition.l) { case List(orCall: Call) => + orCall.name shouldBe Operators.logicalOr + orCall.typeFullName shouldBe TypeConstants.Boolean + + // TODO fix code + // orCall.code shouldBe "!(o instanceof String s) || (s = (String) o).isEmpty()" + + inside(orCall.argument.l) { case List(notCall: Call, isEmptyCall: Call) => + notCall.code shouldBe "!(o instanceof String s)" + inside(notCall.argument.l) { case List(instanceOfCall: Call) => + instanceOfCall.name shouldBe Operators.instanceOf + instanceOfCall.code shouldBe "o instanceof String" + + inside(instanceOfCall.argument.l) { case List(oIdentifier: Identifier, stringType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + + stringType.typeFullName shouldBe "java.lang.String" + } + } + + isEmptyCall.name shouldBe "isEmpty" + isEmptyCall.methodFullName shouldBe "java.lang.String.isEmpty:boolean()" + // TODO fix code: is currently s = (String) o.isEmpty() + // isEmptyCall.code shouldBe "(s = (String) o).isEmpty()" + + inside(isEmptyCall.argument.l) { case List(sAssignment: Call) => + sAssignment.name shouldBe Operators.assignment + sAssignment.code shouldBe "s = (String) o" + sAssignment.typeFullName shouldBe "java.lang.String" + + inside(sAssignment.argument.l) { case List(sIdentifier: Identifier, castCall: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.method.name("test").local.name("s").l + + castCall.name shouldBe Operators.cast + castCall.methodFullName shouldBe Operators.cast + castCall.code shouldBe "(String) o" + + inside(castCall.argument.l) { case List(innerStringType: TypeRef, innerOIdentifier: Identifier) => + innerStringType.typeFullName shouldBe "java.lang.String" + + innerOIdentifier.name shouldBe "o" + innerOIdentifier.typeFullName shouldBe "java.lang.Object" + innerOIdentifier.refsTo.l shouldBe cpg.method.name("test").parameter.name("o").l + } + } + } + } + } + } + } + } + + "a variable is introduced to the RHS of an && expression, mutated and introduced to the body of an if" should { + val cpg = code(""" + |class Test { + | static void test(Object o) { + | if (o instanceof String value && (value = "Foo").isEmpty()) { + | System.out.println(value); + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("isEmpty").nonEmpty shouldBe true + } + + "be represented correctly" in { + inside(cpg.method.name("test").body.astChildren.l) { case List(valueLocal: Local, ifStmt: ControlStructure) => + valueLocal.name shouldBe "value" + valueLocal.code shouldBe "String value" + valueLocal.typeFullName shouldBe "java.lang.String" + + // TODO Fix code + // ifStmt.code shouldBe "if (o instanceof String value && ((value = (String) o) = \"Foo\").isEmpty())" + + inside(ifStmt.condition.l) { case List(andCall: Call) => + andCall.name shouldBe Operators.logicalAnd + andCall.methodFullName shouldBe Operators.logicalAnd + // TODO Test code + + inside(andCall.astChildren.l) { case List(instanceOfCall: Call, isEmptyCall: Call) => + instanceOfCall.name shouldBe Operators.instanceOf + instanceOfCall.code shouldBe "o instanceof String" + + isEmptyCall.name shouldBe "isEmpty" + isEmptyCall.methodFullName shouldBe "java.lang.String.isEmpty:boolean()" + + inside(isEmptyCall.argument.l) { case List(fooAssignment: Call) => + fooAssignment.name shouldBe Operators.assignment + // TODO Test code + + inside(fooAssignment.argument.l) { case List(valueAssign: Call, fooLiteral: Literal) => + valueAssign.methodFullName shouldBe Operators.assignment + valueAssign.code shouldBe "value = (String) o" + + inside(valueAssign.argument.l) { case List(valueIdentifier: Identifier, oCast: Call) => + valueIdentifier.name shouldBe "value" + valueIdentifier.typeFullName shouldBe "java.lang.String" + valueIdentifier.code shouldBe "value" + valueIdentifier.refsTo.l shouldBe List(valueLocal) + + oCast.code shouldBe "(String) o" + } + } + } + } + } + } + } + } + + "a variable is introduced to the RHS of an || expression, mutated and introduced by an if" should { + val cpg = code(""" + |class Test { + | static void test(Object o) { + | if (!(o instanceof String value) || (value = "Foo").isEmpty()) { + | return; + | } + | System.out.println(value); + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("isEmpty").nonEmpty shouldBe true + } + + "be represented correctly" in { + inside(cpg.method.name("test").body.astChildren.l) { + case List(valueLocal: Local, ifStmt: ControlStructure, printCall: Call) => + valueLocal.name shouldBe "value" + valueLocal.code shouldBe "String value" + valueLocal.typeFullName shouldBe "java.lang.String" + + // TODO Fix code + // ifStmt.code shouldBe "if (o instanceof String value && ((value = (String) o) = \"Foo\").isEmpty())" + + inside(ifStmt.condition.l) { case List(orCall: Call) => + orCall.name shouldBe Operators.logicalOr + orCall.methodFullName shouldBe Operators.logicalOr + // TODO Test code + + inside(orCall.astChildren.l) { case List(notCall: Call, isEmptyCall: Call) => + notCall.methodFullName shouldBe Operators.logicalNot + // TODO Test code + + inside(notCall.argument.l) { case List(instanceOfCall: Call) => + instanceOfCall.name shouldBe Operators.instanceOf + instanceOfCall.code shouldBe "o instanceof String" + } + + isEmptyCall.name shouldBe "isEmpty" + isEmptyCall.methodFullName shouldBe "java.lang.String.isEmpty:boolean()" + + inside(isEmptyCall.argument.l) { case List(fooAssignment: Call) => + fooAssignment.name shouldBe Operators.assignment + // TODO Test code + + inside(fooAssignment.argument.l) { case List(valueAssign: Call, fooLiteral: Literal) => + valueAssign.methodFullName shouldBe Operators.assignment + valueAssign.code shouldBe "value = (String) o" + + inside(valueAssign.argument.l) { case List(valueIdentifier: Identifier, oCast: Call) => + valueIdentifier.name shouldBe "value" + valueIdentifier.typeFullName shouldBe "java.lang.String" + valueIdentifier.code shouldBe "value" + valueIdentifier.refsTo.l shouldBe List(valueLocal) + + oCast.code shouldBe "(String) o" + } + } + } + } + } + + printCall.name shouldBe "println" + inside(printCall.argument.l) { case List(systemOutFieldAccess: Call, valueIdentifier: Identifier) => + systemOutFieldAccess.code shouldBe "System.out" + + valueIdentifier.name shouldBe "value" + valueIdentifier.typeFullName shouldBe "java.lang.String" + valueIdentifier.code shouldBe "value" + valueIdentifier.refsTo.l shouldBe List(valueLocal) + } + } + } + } + } + + "patterns in ternary expressions" when { + "a variable is introduced to the then expression" should { + val cpg = code(""" + |class Test { + | void test(Object o) { + | int x = o instanceof String s ? s.length() : -1; + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("length").nonEmpty shouldBe true + } + + "be represented correctly" in { + inside(cpg.method.name("test").body.astChildren.l) { case List(sLocal: Local, xLocal: Local, xAssign: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + + xLocal.name shouldBe "x" + xLocal.typeFullName shouldBe "int" + + xAssign.methodFullName shouldBe Operators.assignment + + inside(xAssign.argument.l) { case List(xIdentifier: Identifier, ternaryExpr: Call) => + xIdentifier.name shouldBe "x" + + ternaryExpr.methodFullName shouldBe Operators.conditional + inside(ternaryExpr.argument.l) { case List(instanceOfCall: Call, lengthCall: Call, minusCall: Call) => + instanceOfCall.name shouldBe Operators.instanceOf + instanceOfCall.code shouldBe "o instanceof String" + + // TODO Test code + lengthCall.name shouldBe "length" + lengthCall.methodFullName shouldBe "java.lang.String.length:int()" + inside(lengthCall.argument.l) { case List(sAssign: Call) => + sAssign.methodFullName shouldBe Operators.assignment + sAssign.typeFullName shouldBe "java.lang.String" + sAssign.code shouldBe "s = (String) o" + + inside(sAssign.argument.l) { case List(sIdentifier: Identifier, oCast: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe List(sLocal) + + oCast.name shouldBe Operators.cast + oCast.code shouldBe "(String) o" + } + } + + minusCall.name shouldBe Operators.minus + } + } + } + } + } + + "a variable is introduced to the else expression" should { + val cpg = code(""" + |class Test { + | void test(Object o) { + | int x = !(o instanceof String s) ? -1: s.length(); + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("length").nonEmpty shouldBe true + } + + "be represented correctly" in { + inside(cpg.method.name("test").body.astChildren.l) { case List(sLocal: Local, xLocal: Local, xAssign: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + + xLocal.name shouldBe "x" + xLocal.typeFullName shouldBe "int" + + xAssign.methodFullName shouldBe Operators.assignment + + inside(xAssign.argument.l) { case List(xIdentifier: Identifier, ternaryExpr: Call) => + xIdentifier.name shouldBe "x" + + ternaryExpr.methodFullName shouldBe Operators.conditional + inside(ternaryExpr.argument.l) { case List(notCall: Call, minusCall: Call, lengthCall: Call) => + notCall.methodFullName shouldBe Operators.logicalNot + + inside(notCall.argument.l) { case List(instanceOfCall: Call) => + instanceOfCall.name shouldBe Operators.instanceOf + instanceOfCall.code shouldBe "o instanceof String" + } + + // TODO Test code + lengthCall.name shouldBe "length" + lengthCall.methodFullName shouldBe "java.lang.String.length:int()" + inside(lengthCall.argument.l) { case List(sAssign: Call) => + sAssign.methodFullName shouldBe Operators.assignment + sAssign.typeFullName shouldBe "java.lang.String" + sAssign.code shouldBe "s = (String) o" + + inside(sAssign.argument.l) { case List(sIdentifier: Identifier, oCast: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe List(sLocal) + + oCast.name shouldBe Operators.cast + oCast.code shouldBe "(String) o" + } + } + + minusCall.name shouldBe Operators.minus + } + } + } + } + } + } + + "patterns in if statements" when { + "a variable is introduced to the then block" should { + val cpg = code(""" + |class Foo { + | Integer s; + | void foo(Object o) { + | if (o instanceof String s) { + | sink(s); + | } + | } + | static void sink(Object o) {} + |} + |""".stripMargin) + + "parse" in { + cpg.identifier.name("s").nonEmpty shouldBe true + } + + "create the s local in the then block" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.l) { + case List(_: Call, thenBlock: Block) => + thenBlock.ast.isLocal.name("s").typeFullName.l shouldBe List("java.lang.String") + } + } + + "create the s assignment in the then block" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.l) { + case List(_: Call, thenBlock: Block) => + thenBlock.ast.isCall + .nameExact(Operators.assignment) + .where(_.argument.isIdentifier.name("s")) + .code + .l shouldBe List("s = (String) o") + } + } + + "create an identifier referring to the s local as the argument for sink" in { + inside(cpg.call.name("sink").argument.l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.local.name("s").l + } + } + } + + "a variable is introduced to the else block" should { + val cpg = code(""" + |class Foo { + | Integer s; + | void foo(Object o) { + | if (!(o instanceof String s)) { + | } else { + | sink(s); + | } + | } + | static void sink(Object o) {} + |} + |""".stripMargin) + + "parse" in { + cpg.identifier.name("s").nonEmpty shouldBe true + } + + "create the s local in the else block" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.l) { + case List(_: Call, _: Block, elseStructure: ControlStructure) => + inside(elseStructure.astChildren.l) { case List(elseBlock: Block) => + elseBlock.ast.isLocal.name("s").typeFullName.l shouldBe List("java.lang.String") + } + } + } + + "create the s assignment in the then block" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.l) { + case List(_: Call, _: Block, elseStructure: ControlStructure) => + inside(elseStructure.astChildren.l) { case List(elseBlock: Block) => + } + } + } + + "create an identifier referring to the s local as the argument for sink" in { + inside(cpg.call.name("sink").argument.l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.local.name("s").l + } + } + } + + "a variable is introduced to the surrounding scope" should { + val cpg = code(""" + |class Foo { + | Integer s; + | void foo(Object o) { + | if (!(o instanceof String s)) { + | return; + | } + | sink(s); + | } + | static void sink(Object o) {} + |} + |""".stripMargin) + + "parse" in { + cpg.identifier.name("s").nonEmpty shouldBe true + } + + "create the s local after the if statement" in { + inside(cpg.method.name("foo").body.astChildren.l) { + case List(_: ControlStructure, sLocal: Local, _: Call, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + } + } + + "create the s assignment before the if statement" in { + inside(cpg.method.name("foo").body.astChildren.l) { + case List(_: ControlStructure, _: Local, assignment: Call, _: Call) => + assignment.name shouldBe Operators.assignment + assignment.code shouldBe "s = (String) o" + } + } + + "have an s identifier as the sink argument with refs to the s local" in { + inside(cpg.call.name("sink").argument.l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.local.name("s").l + } + } + } + + "a variable is introduced to the else block and surrounding scope" should { + val cpg = code(""" + |class Foo { + | Integer s; + | void foo(Object o) { + | if (!(o instanceof String s)) { + | sink1(s); + | return; + | } else { + | sink2(s); + | } + | sink3(s); + | } + | static void sink1(Object o) {} + | static void sink2(Object o) {} + | static void sink3(Object o) {} + |} + |""".stripMargin) + + "parse" in { + cpg.identifier.name("s").nonEmpty shouldBe true + } + + "create the s local before the if statement" in { + inside(cpg.method.name("foo").body.astChildren.l) { + case List(sLocal: Local, _: Call, _: ControlStructure, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + } + } + + "create the s assignment before the if statement" in { + inside(cpg.method.name("foo").body.astChildren.l) { + case List(_: Local, assignment: Call, _: ControlStructure, _: Call) => + assignment.name shouldBe Operators.assignment + assignment.code shouldBe "s = (String) o" + } + } + + "have an s field access as the sink1 argument" in { + inside(cpg.call.name("sink1").argument.l) { case List(fieldAccess: Call) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.s" + fieldAccess.typeFullName shouldBe "java.lang.Integer" + } + } + + "have an s identifier as the sink2 argument with refs to the s local" in { + inside(cpg.call.name("sink2").argument.l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.local.name("s").l + } + } + + "have an s identifier as the sink3 argument with refs to the s local" in { + inside(cpg.call.name("sink3").argument.l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.local.name("s").l + } + } + } + } + + "patterns in while statements" when { + "a variable is introduced to the body" should { + val cpg = code(""" + |class Foo { + | Integer s; + | void foo(Object o) { + | while (o instanceof String s) { + | sink1(s); + | } + | sink2(s); + | } + | static void sink1(Object o) {} + | static void sink2(Object o) {} + |} + |""".stripMargin) + + "parse" in { + cpg.identifier.name("s").nonEmpty shouldBe true + } + + "create the s local in the while body" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).astChildren.l) { + case List(_: Call, body: Block) => + inside(body.astChildren.l) { case List(sLocal: Local, _: Call, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + } + } + } + + "create the s assignment in the body" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).astChildren.l) { + case List(_: Call, body: Block) => + inside(body.astChildren.l) { case List(_: Local, assignment: Call, _: Call) => + assignment.name shouldBe Operators.assignment + assignment.code shouldBe "s = (String) o" + } + } + } + + "have the argument of sink1 be an s identifier with refs to the s local" in { + inside(cpg.call.name("sink1").argument.l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.local.name("s").l + } + } + + "have the argument of sink2 be a field access for s" in { + inside(cpg.call.name("sink2").argument.l) { case List(fieldAccess: Call) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.s" + fieldAccess.typeFullName shouldBe "java.lang.Integer" + } + } + } + + "a variable is introduced by the while" should { + val cpg = code(""" + |class Foo { + | Integer s; + | void foo(Object o) { + | while (!(o instanceof String s)) { + | sink1(s); + | } + | sink2(s); + | } + | static void sink1(Object o) {} + | static void sink2(Object o) {} + |} + |""".stripMargin) + + "parse" in { + cpg.identifier.name("s").nonEmpty shouldBe true + } + + "create the s local after the while loop" in { + inside(cpg.method.name("foo").body.astChildren.l) { + case List(_: ControlStructure, sLocal: Local, _: Call, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + } + } + + "create the s assignment after the while loop" in { + inside(cpg.method.name("foo").body.astChildren.l) { + case List(_: ControlStructure, _: Local, assignment: Call, _: Call) => + assignment.name shouldBe Operators.assignment + assignment.code shouldBe "s = (String) o" + } + } + + "have the argument of sink1 be a field access for s" in { + inside(cpg.call.name("sink1").argument.l) { case List(fieldAccess: Call) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.s" + fieldAccess.typeFullName shouldBe "java.lang.Integer" + } + } + + "have the argument of sink2 be an s identifier with refs to the s local" in { + inside(cpg.call.name("sink2").argument.l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.local.name("s").l + } + } + } + } + + "patterns in do statements" when { + "a variable is introduced by the do" should { + val cpg = code(""" + |class Foo { + | Integer s; + | void foo(Object o) { + | do { sink1(s); } while (!(o instanceof String s)); + | sink2(s); + | } + | static void sink1(Object o) {} + | static void sink2(Object o) {} + |} + |""".stripMargin) + + "parse" in { + cpg.identifier.name("s").nonEmpty shouldBe true + } + + "create the s local after the do loop" in { + inside(cpg.method.name("foo").body.astChildren.l) { + case List(_: ControlStructure, sLocal: Local, _: Call, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + } + } + + "create the s assignment after the do loop" in { + inside(cpg.method.name("foo").body.astChildren.l) { + case List(_: ControlStructure, _: Local, assignment: Call, _: Call) => + assignment.name shouldBe Operators.assignment + assignment.code shouldBe "s = (String) o" + } + } + + "have the argument of sink1 be a field access for s" in { + inside(cpg.call.name("sink1").argument.l) { case List(fieldAccess: Call) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.s" + fieldAccess.typeFullName shouldBe "java.lang.Integer" + } + } + + "have the argument of sink2 be an s identifier with refs to the s local" in { + inside(cpg.call.name("sink2").argument.l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.local.name("s").l + } + } + + } + } + + "patterns in for statements" when { + "a variable is introduced to the for update" should { + val cpg = code(""" + |class Foo { + | void foo(Object o) { + | for(int i = 0; o instanceof String s && i < 42; i += s.length()) { + | System.out.println(i); + | } + | } + |} + |""".stripMargin) + + "be represented correctly" in { + inside(cpg.method.name("foo").body.astChildren.l) { case List(sLocal: Local, forStmt: ControlStructure) => + sLocal.name shouldBe "s" + sLocal.code shouldBe "String s" + sLocal.typeFullName shouldBe "java.lang.String" + + forStmt.controlStructureType shouldBe ControlStructureTypes.FOR + inside(forStmt.astChildren.l) { + case List(iLocal: Local, iAssign: Call, condition: Call, update: Call, body: Block) => + iLocal.name shouldBe "i" + + iAssign.methodFullName shouldBe Operators.assignment + iAssign.code shouldBe "int i = 0" + + condition.methodFullName shouldBe Operators.logicalAnd + // TODO Check LHS arg + + update.methodFullName shouldBe Operators.assignmentPlus + inside(update.argument.l) { case List(iIdentifier: Identifier, lengthCall: Call) => + iIdentifier.name shouldBe "i" + iIdentifier.refsTo.l shouldBe List(iLocal) + + lengthCall.name shouldBe "length" + lengthCall.methodFullName shouldBe "java.lang.String.length:int()" + // TODO Test code + // TODO This representation is technically not correct. It's possible to + inside(lengthCall.argument.l) { case List(sAssignment: Call) => + sAssignment.methodFullName shouldBe Operators.assignment + sAssignment.code shouldBe "s = (String) o" + } + } + + inside(body.astChildren.l) { case List(printlnCall: Call) => + printlnCall.code shouldBe "System.out.println(i)" + } + + } + } + } + } + + "a variable is introduced to the for body" should { + val cpg = code(""" + |class Foo { + | Integer s; + | void foo(Object o) { + | for (; o instanceof String s;) { + | sink1(s); + | } + | sink2(s); + | } + | static void sink1(Object o) {} + | static void sink2(Object o) {} + |} + |""".stripMargin) + + "parse" in { + cpg.identifier.name("s").nonEmpty shouldBe true + } + + "create the s local in the for body" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.FOR).astChildren.l) { + case List(_: Call, body: Block) => + inside(body.astChildren.l) { case List(sLocal: Local, _: Call, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + } + } + } + + "create the s assignment in the for body" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.FOR).astChildren.l) { + case List(_: Call, body: Block) => + inside(body.astChildren.l) { case List(_: Local, assignment: Call, _: Call) => + assignment.name shouldBe Operators.assignment + assignment.code shouldBe "s = (String) o" + } + } + } + + "have the argument of sink1 be an s identifier with refs to the s local" in { + inside(cpg.call.name("sink1").argument.l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.local.name("s").l + } + } + + "have the argument of sink2 be a field access for s" in { + inside(cpg.call.name("sink2").argument.l) { case List(fieldAccess: Call) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.s" + fieldAccess.typeFullName shouldBe "java.lang.Integer" + } + } + } + + "a variable is introduced by the for" should { + val cpg = code(""" + |class Foo { + | Integer s; + | void foo(Object o) { + | for (; !(o instanceof String s);) { + | sink1(s); + | } + | sink2(s); + | } + | static void sink1(Object o) {} + | static void sink2(Object o) {} + |} + |""".stripMargin) + "parse" in { + cpg.identifier.name("s").nonEmpty shouldBe true + } + + "create the s local after the for loop" in { + inside(cpg.method.name("foo").body.astChildren.l) { + case List(_: ControlStructure, sLocal: Local, _: Call, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + } + } + + "create the s assignment after the for loop" in { + inside(cpg.method.name("foo").body.astChildren.l) { + case List(_: ControlStructure, _: Local, assignment: Call, _: Call) => + assignment.name shouldBe Operators.assignment + assignment.code shouldBe "s = (String) o" + } + } + + "have the argument of sink1 be a field access for s" in { + inside(cpg.call.name("sink1").argument.l) { case List(fieldAccess: Call) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.s" + fieldAccess.typeFullName shouldBe "java.lang.Integer" + } + } + + "have the argument of sink2 be an s identifier with refs to the s local" in { + inside(cpg.call.name("sink2").argument.l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe cpg.local.name("s").l + } + } + } + } + + "resolved patterns in instanceof expressions" when { + "a type pattern is matched" should { + val cpg = code(""" + |class Foo { + | void foo(Object o) { + | if (o instanceof String s) { + | sink(s); + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").isEmpty shouldBe false + } + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(instanceOfCall: Call) => + instanceOfCall.name shouldBe Operators.instanceOf + instanceOfCall.typeFullName shouldBe "boolean" + instanceOfCall.code shouldBe "o instanceof String" + + inside(instanceOfCall.argument.l) { case List(oIdentifier: Identifier, stringType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.code shouldBe "o" + + stringType.typeFullName shouldBe "java.lang.String" + stringType.code shouldBe "String" + } + } + } + + "have the correct lowering for the variable assignment" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List(sLocal: Local, sAssign: Call, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + sLocal.code shouldBe "String s" + + sAssign.name shouldBe Operators.assignment + sAssign.methodFullName shouldBe Operators.assignment + sAssign.code shouldBe "s = (String) o" + + inside(sAssign.argument.l) { case List(sIdentifier: Identifier, castExpr: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.code shouldBe "s" + sIdentifier.refsTo.l should contain theSameElementsAs List(sLocal) + + castExpr.name shouldBe Operators.cast + castExpr.methodFullName shouldBe Operators.cast + castExpr.typeFullName shouldBe "java.lang.String" + castExpr.code shouldBe "(String) o" + + inside(castExpr.argument.l) { case List(stringType: TypeRef, oIdentifier: Identifier) => + stringType.typeFullName shouldBe "java.lang.String" + stringType.code shouldBe "String" + + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.code shouldBe "o" + oIdentifier.refsTo.l should contain theSameElementsAs cpg.method.name("foo").parameter.name("o").l + } + } + } + } + } + + "a non-generic, non-nested record pattern is matched" should { + val cpg = code(""" + |package box; + | + |record Box(String value) {} + | + |class Foo { + | void foo(Object o) { + | if (o instanceof Box(String s)) { + | sink(s); + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").isEmpty shouldBe false + } + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(instanceOfBox: Call) => + instanceOfBox.name shouldBe Operators.instanceOf + instanceOfBox.methodFullName shouldBe Operators.instanceOf + instanceOfBox.code shouldBe "o instanceof Box" + instanceOfBox.typeFullName shouldBe "boolean" + + inside(instanceOfBox.argument.l) { case List(oIdentifier: Identifier, boxType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.code shouldBe "o" + oIdentifier.refsTo.l shouldBe cpg.method.name("foo").parameter.name("o").l + + boxType.typeFullName shouldBe "box.Box" + boxType.code shouldBe "Box" + } + } + } + + "have the correct lowering for the variable assignment" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List(sLocal: Local, sAssignment: Call, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + sLocal.code shouldBe "String s" + + sAssignment.name shouldBe Operators.assignment + sAssignment.methodFullName shouldBe Operators.assignment + sAssignment.typeFullName shouldBe "java.lang.String" + sAssignment.code shouldBe "s = ((Box) o).value()" + + inside(sAssignment.argument.l) { case List(sIdentifier: Identifier, valueCall: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.code shouldBe "s" + sIdentifier.refsTo.l shouldBe List(sLocal) + + valueCall.name shouldBe "value" + valueCall.methodFullName shouldBe "box.Box.value:java.lang.String()" + valueCall.code shouldBe "((Box) o).value()" + valueCall.typeFullName shouldBe "java.lang.String" + + inside(valueCall.receiver.l) { case List(boxCast: Call) => + boxCast.name shouldBe Operators.cast + boxCast.code shouldBe "(Box) o" + boxCast.typeFullName shouldBe "box.Box" + + inside(boxCast.argument.l) { case List(boxType: TypeRef, oIdentifier: Identifier) => + boxType.typeFullName shouldBe "box.Box" + + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.refsTo.l shouldBe cpg.parameter.name("o").l + } + } + } + } + } + } + + "a generic, non-nested record pattern is matched" should { + val cpg = code(""" + |package box; + | + |record Box(T value) {} + | + |class Foo { + | void foo(Object o) { + | if (o instanceof Box(String s)) { + | sink(s); + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").isEmpty shouldBe false + } + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(andCall: Call) => + andCall.name shouldBe Operators.logicalAnd + andCall.methodFullName shouldBe Operators.logicalAnd + andCall.code shouldBe "(o instanceof Box) && (($obj0 = ((Box) o).value()) instanceof String)" + + inside(andCall.argument.l) { case List(instanceOfBox: Call, instanceOfString: Call) => + instanceOfBox.name shouldBe Operators.instanceOf + instanceOfBox.methodFullName shouldBe Operators.instanceOf + instanceOfBox.code shouldBe "o instanceof Box" + instanceOfBox.typeFullName shouldBe "boolean" + + inside(instanceOfBox.argument.l) { case List(oIdentifier: Identifier, boxType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.code shouldBe "o" + oIdentifier.refsTo.l shouldBe cpg.method.name("foo").parameter.name("o").l + + boxType.typeFullName shouldBe "box.Box" + boxType.code shouldBe "Box" + } + + instanceOfString.name shouldBe Operators.instanceOf + instanceOfString.methodFullName shouldBe Operators.instanceOf + instanceOfString.code shouldBe "($obj0 = ((Box) o).value()) instanceof String" + instanceOfString.typeFullName shouldBe "boolean" + + inside(instanceOfString.argument.l) { case List(tmpAssign: Call, stringType: TypeRef) => + tmpAssign.name shouldBe Operators.assignment + tmpAssign.methodFullName shouldBe Operators.assignment + tmpAssign.code shouldBe "$obj0 = ((Box) o).value()" + tmpAssign.typeFullName shouldBe "java.lang.Object" + + inside(tmpAssign.argument.l) { case List(tmpIdentifier0: Identifier, valueCall: Call) => + tmpIdentifier0.name shouldBe "$obj0" + tmpIdentifier0.code shouldBe "$obj0" + tmpIdentifier0.typeFullName shouldBe "java.lang.Object" + tmpIdentifier0.refsTo.l shouldBe cpg.local.nameExact("$obj0").l + + valueCall.name shouldBe "value" + valueCall.methodFullName shouldBe "box.Box.value:java.lang.Object()" + valueCall.code shouldBe "((Box) o).value()" + valueCall.typeFullName shouldBe "java.lang.Object" + + inside(valueCall.argument.l) { case List(castExpr: Call) => + castExpr.name shouldBe Operators.cast + castExpr.methodFullName shouldBe Operators.cast + castExpr.typeFullName shouldBe "box.Box" + inside(castExpr.argument.l) { case List(castBoxType: TypeRef, castOIdentifier: Identifier) => + castBoxType.typeFullName shouldBe "box.Box" + castBoxType.code shouldBe "Box" + + castOIdentifier.name shouldBe "o" + castOIdentifier.typeFullName shouldBe "java.lang.Object" + castOIdentifier.code shouldBe "o" + castOIdentifier.refsTo.l shouldBe cpg.method.name("foo").parameter.name("o").l + } + } + } + + stringType.typeFullName shouldBe "java.lang.String" + stringType.code shouldBe "String" + } + } + } + } + + "have the correct lowering for the variable assignment" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List(sLocal: Local, sAssignment: Call, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + sLocal.code shouldBe "String s" + + sAssignment.name shouldBe Operators.assignment + sAssignment.methodFullName shouldBe Operators.assignment + sAssignment.typeFullName shouldBe "java.lang.String" + sAssignment.code shouldBe "s = (String) $obj0" + + inside(sAssignment.argument.l) { case List(sIdentifier: Identifier, stringCast: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.code shouldBe "s" + sIdentifier.refsTo.l shouldBe List(sLocal) + + stringCast.name shouldBe Operators.cast + stringCast.methodFullName shouldBe Operators.cast + stringCast.typeFullName shouldBe "java.lang.String" + stringCast.code shouldBe "(String) $obj0" + + inside(stringCast.argument.l) { case List(stringType: TypeRef, tmpIdentifier0: Identifier) => + stringType.typeFullName shouldBe "java.lang.String" + + tmpIdentifier0.name shouldBe "$obj0" + tmpIdentifier0.code shouldBe "$obj0" + tmpIdentifier0.typeFullName shouldBe "java.lang.Object" + tmpIdentifier0.refsTo.l shouldBe cpg.local.nameExact("$obj0").l + } + } + } + } + } + + "a non-generic, nested record pattern is matched" should { + + val cpg = code(""" + |package box; + | + |record PairBox(Pair value) {} + |record Pair(String first, Integer second) {} + | + |class Foo { + | void foo(Object o) { + | if (o instanceof PairBox(Pair(String s, Integer i))) { + | sink(s); + | sink(i); + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").isEmpty shouldBe false + } + + "have the correct lowering for the type check" in { + val oParameter = cpg.method.name("foo").parameter.name("o").l + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(oInstanceOfPairBox: Call) => + oInstanceOfPairBox.name shouldBe Operators.instanceOf + oInstanceOfPairBox.methodFullName shouldBe Operators.instanceOf + oInstanceOfPairBox.typeFullName shouldBe "boolean" + oInstanceOfPairBox.code shouldBe "o instanceof PairBox" + + inside(oInstanceOfPairBox.argument.l) { case List(oIdentifier: Identifier, pairBoxType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.code shouldBe "o" + oIdentifier.refsTo.l shouldBe oParameter + + pairBoxType.typeFullName shouldBe "box.PairBox" + pairBoxType.code shouldBe "PairBox" + } + + } + } + + "have the correct lowering for the variable assignment" in { + val oParameter = cpg.method.name("foo").parameter.name("o").l + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List(sLocal: Local, sAssign: Call, iLocal: Local, iAssign: Call, sSink: Call, iSink: Call) => + sLocal.name shouldBe "s" + sLocal.code shouldBe "String s" + sLocal.typeFullName shouldBe "java.lang.String" + + iLocal.name shouldBe "i" + iLocal.code shouldBe "Integer i" + iLocal.typeFullName shouldBe "java.lang.Integer" + + sAssign.name shouldBe Operators.assignment + sAssign.methodFullName shouldBe Operators.assignment + sAssign.typeFullName shouldBe "java.lang.String" + sAssign.code shouldBe "s = ($obj0 = ((PairBox) o).value()).first()" + + inside(sAssign.argument.l) { case List(sIdentifier: Identifier, firstCall: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.code shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe List(sLocal) + + firstCall.name shouldBe "first" + firstCall.methodFullName shouldBe "box.Pair.first:java.lang.String()" + firstCall.typeFullName shouldBe "java.lang.String" + firstCall.code shouldBe "($obj0 = ((PairBox) o).value()).first()" + + inside(firstCall.argument.l) { case List(tmpAssign0: Call) => + tmpAssign0.name shouldBe Operators.assignment + tmpAssign0.code shouldBe "$obj0 = ((PairBox) o).value()" + tmpAssign0.typeFullName shouldBe "box.Pair" + + inside(tmpAssign0.argument.l) { case List(tmpIdentifier0: Identifier, valueCall: Call) => + tmpIdentifier0.name shouldBe "$obj0" + tmpIdentifier0.code shouldBe "$obj0" + tmpIdentifier0.typeFullName shouldBe "box.Pair" + tmpIdentifier0.refsTo.l shouldBe cpg.method("foo").local.nameExact("$obj0").l + + valueCall.name shouldBe "value" + valueCall.typeFullName shouldBe "box.Pair" + valueCall.methodFullName shouldBe "box.PairBox.value:box.Pair()" + + inside(valueCall.argument.l) { case List(pairBoxCast: Call) => + pairBoxCast.name shouldBe Operators.cast + pairBoxCast.code shouldBe "(PairBox) o" + pairBoxCast.typeFullName shouldBe "box.PairBox" + + inside(pairBoxCast.argument.l) { case List(pairBoxType: TypeRef, oIdentifier: Identifier) => + pairBoxType.typeFullName shouldBe "box.PairBox" + + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.refsTo.l shouldBe cpg.parameter.name("o").l + } + } + } + } + } + + iAssign.name shouldBe Operators.assignment + iAssign.methodFullName shouldBe Operators.assignment + iAssign.typeFullName shouldBe "java.lang.Integer" + iAssign.code shouldBe "i = $obj0.second()" + + inside(iAssign.argument.l) { case List(iIdentifier: Identifier, secondCall: Call) => + iIdentifier.name shouldBe "i" + iIdentifier.code shouldBe "i" + iIdentifier.typeFullName shouldBe "java.lang.Integer" + iIdentifier.refsTo.l shouldBe List(iLocal) + + secondCall.name shouldBe "second" + secondCall.methodFullName shouldBe "box.Pair.second:java.lang.Integer()" + secondCall.code shouldBe "$obj0.second()" + secondCall.typeFullName shouldBe "java.lang.Integer" + + inside(secondCall.argument.l) { case List(tmpIdentifier0: Identifier) => + tmpIdentifier0.name shouldBe "$obj0" + tmpIdentifier0.typeFullName shouldBe "box.Pair" + tmpIdentifier0.refsTo.l shouldBe cpg.local.nameExact("$obj0").l + } + } + + inside(sSink.argument.isIdentifier.name("s").l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.code shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe List(sLocal) + } + + inside(iSink.argument.isIdentifier.name("i").l) { case List(iIdentifier: Identifier) => + iIdentifier.name shouldBe "i" + iIdentifier.code shouldBe "i" + iIdentifier.typeFullName shouldBe "java.lang.Integer" + iIdentifier.refsTo.l shouldBe List(iLocal) + } + } + } + } + + "a complex mixed record pattern" should { + val cpg = code(""" + |record A(B a0, C a1) {} + |record B(String b0) {} + |record C(D c0, F c1) {} + |record D(String d0, E d1) {} + |record E(String e0) {} + |record F(G f0) {} + |record G(String g0, T g1) {} + | + |class Test { + | void test(Object o) { + | if (o instanceof A(B(String b0), C(D(String d0, E(String e0)), F(G(String g0, Integer g1))))) { } + | } + |} + |""".stripMargin) + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(firstAnd: Call) => + firstAnd.name shouldBe Operators.logicalAnd + firstAnd.code shouldBe "(o instanceof A) && (($obj2 = ($obj1 = ($obj0 = ((A) o).a1()).c1().f0()).g1()) instanceof Integer)" + } + } + + "have the correct lowering for the assignments" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List( + b0Local: Local, + b0Assign: Call, + d0Local: Local, + d0Assign: Call, + e0Local: Local, + e0Assign: Call, + g0Local: Local, + g0Assign: Call, + g1Local: Local, + g1Assign: Call + ) => + b0Local.name shouldBe "b0" + d0Local.name shouldBe "d0" + e0Local.name shouldBe "e0" + g0Local.name shouldBe "g0" + g1Local.name shouldBe "g1" + + b0Assign.code shouldBe "b0 = ((A) o).a0().b0()" + d0Assign.code shouldBe "d0 = ($obj3 = $obj0.c0()).d0()" + e0Assign.code shouldBe "e0 = $obj3.d1().e0()" + g0Assign.code shouldBe "g0 = $obj1.g0()" + g1Assign.code shouldBe "g1 = (Integer) $obj2" + } + } + } + + "a mixed record pattern where nested first child and second child needs instanceof" should { + val cpg = code(""" + |record Foo(T value) {} + |record Bar(Foo left, T right) {} + | + |class Test { + | void test(Object o) { + | if (o instanceof Foo(Bar(Foo(String s), Integer i))) { } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.nameExact(Operators.instanceOf).nonEmpty shouldBe true + } + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(firstAnd: Call) => + firstAnd.name shouldBe Operators.logicalAnd + firstAnd.code shouldBe "(o instanceof Foo) && ((($obj0 = ((Foo) o).value()) instanceof Bar) && ((($obj1 = ((Bar) $obj0).left().value()) instanceof String) && (($obj2 = ((Bar) $obj0).right()) instanceof Integer)))" + } + } + + "have the correct lowering for the assignments" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List(sLocal: Local, sAssign: Call, iLocal: Local, iAssign: Call) => + sLocal.name shouldBe "s" + sAssign.code shouldBe "s = (String) $obj1" + iLocal.name shouldBe "i" + iAssign.code shouldBe "i = (Integer) $obj2" + } + } + } + + "a mixed record pattern where only the second child needs instanceof" should { + val cpg = code(""" + |record Foo(T value) {} + |record Bar(String left, T right) {} + | + |class Test { + | void test(Object o) { + | if (o instanceof Foo(Bar(String s, Integer i))) { } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.nameExact(Operators.instanceOf).nonEmpty shouldBe true + } + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(firstAnd: Call) => + firstAnd.name shouldBe Operators.logicalAnd + firstAnd.code shouldBe "(o instanceof Foo) && ((($obj0 = ((Foo) o).value()) instanceof Bar) && (($obj1 = ((Bar) $obj0).right()) instanceof Integer))" + } + } + + "have the correct lowering for the assignments" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List(sLocal: Local, sAssign: Call, iLocal: Local, iAssign: Call) => + sLocal.name shouldBe "s" + sAssign.code shouldBe "s = ((Bar) $obj0).left()" + iLocal.name shouldBe "i" + iAssign.code shouldBe "i = (Integer) $obj1" + } + } + } + + "a mixed generic record pattern is matched" should { + val cpg = code(""" + |record Foo(T foo) {} + |record Bar(Baz bar) {} + |record Baz(T baz) {} + |record Qux(String qux) {} + | + |class Test { + | void test(Object o) { + | if (o instanceof Foo(Bar(Baz(Qux(String s))))) { + | sink(s); + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").isEmpty shouldBe false + } + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(firstAnd: Call) => + firstAnd.name shouldBe Operators.logicalAnd + } + } + } + + "a generic, nested record pattern is matched" should { + + val cpg = code(""" + |package box; + | + |record Box(T value) {} + |record Pair(U first, V second) {} + | + |class Foo { + | void foo(Object o) { + | if (o instanceof Box(Pair(String s, Integer i))) { + | sink(s); + | sink(i); + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").isEmpty shouldBe false + } + + "have the correct lowering for the type check" in { + val oParameter = cpg.method.name("foo").parameter.name("o").l + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(firstAnd: Call) => + firstAnd.name shouldBe Operators.logicalAnd + firstAnd.methodFullName shouldBe Operators.logicalAnd + firstAnd.typeFullName shouldBe "boolean" + firstAnd.code shouldBe "(o instanceof Box) && ((($obj0 = ((Box) o).value()) instanceof Pair) && ((($obj1 = ((Pair) $obj0).first()) instanceof String) && (($obj2 = ((Pair) $obj0).second()) instanceof Integer)))" + + inside(firstAnd.argument.l) { case List(oInstanceOfBox: Call, secondAnd: Call) => + oInstanceOfBox.name shouldBe Operators.instanceOf + oInstanceOfBox.methodFullName shouldBe Operators.instanceOf + oInstanceOfBox.typeFullName shouldBe "boolean" + oInstanceOfBox.code shouldBe "o instanceof Box" + + inside(oInstanceOfBox.argument.l) { case List(oIdentifier: Identifier, boxType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.code shouldBe "o" + oIdentifier.refsTo.l shouldBe oParameter + + boxType.typeFullName shouldBe "box.Box" + boxType.code shouldBe "Box" + } + + secondAnd.name shouldBe Operators.logicalAnd + secondAnd.methodFullName shouldBe Operators.logicalAnd + secondAnd.typeFullName shouldBe "boolean" + secondAnd.code shouldBe "(($obj0 = ((Box) o).value()) instanceof Pair) && ((($obj1 = ((Pair) $obj0).first()) instanceof String) && (($obj2 = ((Pair) $obj0).second()) instanceof Integer))" + + inside(secondAnd.argument.l) { case List(oValueInstanceOfPair: Call, thirdAnd: Call) => + oValueInstanceOfPair.name shouldBe Operators.instanceOf + oValueInstanceOfPair.methodFullName shouldBe Operators.instanceOf + oValueInstanceOfPair.typeFullName shouldBe "boolean" + oValueInstanceOfPair.code shouldBe "($obj0 = ((Box) o).value()) instanceof Pair" + + inside(oValueInstanceOfPair.argument.l) { case List(tmpAssignment: Call, pairType: TypeRef) => + inside(tmpAssignment.argument.l) { case List(tmpIdentifier0: Identifier, valueCall: Call) => + valueCall.name shouldBe "value" + valueCall.methodFullName shouldBe "box.Box.value:java.lang.Object()" + valueCall.signature shouldBe "java.lang.Object()" + valueCall.code shouldBe "((Box) o).value()" + + inside(valueCall.argument.l) { case List(castExpr: Call) => + castExpr.name shouldBe Operators.cast + castExpr.methodFullName shouldBe Operators.cast + castExpr.typeFullName shouldBe "box.Box" + castExpr.code shouldBe "(Box) o" + + inside(castExpr.argument.l) { case List(boxType: TypeRef, oIdentifier: Identifier) => + boxType.typeFullName shouldBe "box.Box" + boxType.code shouldBe "Box" + + oIdentifier.name shouldBe "o" + oIdentifier.code shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.refsTo.l shouldBe oParameter + } + } + } + + pairType.typeFullName shouldBe "box.Pair" + pairType.code shouldBe "Pair" + } + + thirdAnd.name shouldBe Operators.logicalAnd + thirdAnd.methodFullName shouldBe Operators.logicalAnd + thirdAnd.typeFullName shouldBe "boolean" + thirdAnd.code shouldBe "(($obj1 = ((Pair) $obj0).first()) instanceof String) && (($obj2 = ((Pair) $obj0).second()) instanceof Integer)" + + inside(thirdAnd.argument.l) { case List(firstInstanceOfString: Call, secondInstanceOfInteger: Call) => + firstInstanceOfString.name shouldBe Operators.instanceOf + firstInstanceOfString.methodFullName shouldBe Operators.instanceOf + firstInstanceOfString.typeFullName shouldBe "boolean" + firstInstanceOfString.code shouldBe "($obj1 = ((Pair) $obj0).first()) instanceof String" + + inside(firstInstanceOfString.argument.l) { case List(tmp1Assign: Call, stringType: TypeRef) => + tmp1Assign.name shouldBe Operators.assignment + tmp1Assign.methodFullName shouldBe Operators.assignment + tmp1Assign.code shouldBe "$obj1 = ((Pair) $obj0).first()" + tmp1Assign.typeFullName shouldBe "java.lang.Object" + + inside(tmp1Assign.argument.l) { case List(tmpIdentifier1: Identifier, firstCall: Call) => + tmpIdentifier1.name shouldBe "$obj1" + tmpIdentifier1.code shouldBe "$obj1" + tmpIdentifier1.typeFullName shouldBe "java.lang.Object" + tmpIdentifier1.refsTo.l shouldBe cpg.local.nameExact("$obj1").l + + firstCall.name shouldBe "first" + firstCall.methodFullName shouldBe "box.Pair.first:java.lang.Object()" + firstCall.typeFullName shouldBe "java.lang.Object" + firstCall.code shouldBe "((Pair) $obj0).first()" + + inside(firstCall.argument.l) { case List(pairCast: Call) => + pairCast.name shouldBe Operators.cast + pairCast.methodFullName shouldBe Operators.cast + pairCast.typeFullName shouldBe "box.Pair" + pairCast.code shouldBe "(Pair) $obj0" + + inside(pairCast.argument.l) { case List(pairType: TypeRef, tmpIdentifier0: Identifier) => + pairType.typeFullName shouldBe "box.Pair" + + tmpIdentifier0.name shouldBe "$obj0" + tmpIdentifier0.code shouldBe "$obj0" + tmpIdentifier0.typeFullName shouldBe "java.lang.Object" + tmpIdentifier0.refsTo.l shouldBe cpg.local.nameExact("$obj0").l + } + } + } + + stringType.typeFullName shouldBe "java.lang.String" + stringType.code shouldBe "String" + } + + secondInstanceOfInteger.name shouldBe Operators.instanceOf + secondInstanceOfInteger.methodFullName shouldBe Operators.instanceOf + secondInstanceOfInteger.typeFullName shouldBe "boolean" + secondInstanceOfInteger.code shouldBe "($obj2 = ((Pair) $obj0).second()) instanceof Integer" + + inside(secondInstanceOfInteger.argument.l) { case List(tmp2Assign: Call, integerType: TypeRef) => + tmp2Assign.name shouldBe Operators.assignment + tmp2Assign.methodFullName shouldBe Operators.assignment + tmp2Assign.code shouldBe "$obj2 = ((Pair) $obj0).second()" + tmp2Assign.typeFullName shouldBe "java.lang.Object" + + inside(tmp2Assign.argument.l) { case List(tmpIdentifier2: Identifier, secondCall: Call) => + tmpIdentifier2.name shouldBe "$obj2" + tmpIdentifier2.code shouldBe "$obj2" + tmpIdentifier2.typeFullName shouldBe "java.lang.Object" + tmpIdentifier2.refsTo.l shouldBe cpg.local.nameExact("$obj2").l + + secondCall.name shouldBe "second" + secondCall.methodFullName shouldBe "box.Pair.second:java.lang.Object()" + secondCall.typeFullName shouldBe "java.lang.Object" + secondCall.code shouldBe "((Pair) $obj0).second()" + + inside(secondCall.argument.l) { case List(pairCast: Call) => + pairCast.name shouldBe Operators.cast + pairCast.methodFullName shouldBe Operators.cast + pairCast.typeFullName shouldBe "box.Pair" + pairCast.code shouldBe "(Pair) $obj0" + + inside(pairCast.argument.l) { case List(pairType: TypeRef, tmpIdentifier0: Identifier) => + pairType.typeFullName shouldBe "box.Pair" + pairType.code shouldBe "Pair" + + tmpIdentifier0.name shouldBe "$obj0" + tmpIdentifier0.code shouldBe "$obj0" + tmpIdentifier0.typeFullName shouldBe "java.lang.Object" + tmpIdentifier0.refsTo.l shouldBe cpg.local.nameExact("$obj0").l + } + } + } + integerType.typeFullName shouldBe "java.lang.Integer" + integerType.code shouldBe "Integer" + } + } + } + } + } + } + + "have the correct lowering for the variable assignment" in { + val oParameter = cpg.method.name("foo").parameter.name("o").l + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List(sLocal: Local, sAssign: Call, iLocal: Local, iAssign: Call, sSink: Call, iSink: Call) => + sLocal.name shouldBe "s" + sLocal.code shouldBe "String s" + sLocal.typeFullName shouldBe "java.lang.String" + + iLocal.name shouldBe "i" + iLocal.code shouldBe "Integer i" + iLocal.typeFullName shouldBe "java.lang.Integer" + + sAssign.name shouldBe Operators.assignment + sAssign.methodFullName shouldBe Operators.assignment + sAssign.typeFullName shouldBe "java.lang.String" + sAssign.code shouldBe "s = (String) $obj1" + + inside(sAssign.argument.l) { case List(sIdentifier: Identifier, stringCast: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.code shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe List(sLocal) + + stringCast.name shouldBe Operators.cast + stringCast.methodFullName shouldBe Operators.cast + stringCast.typeFullName shouldBe "java.lang.String" + stringCast.code shouldBe "(String) $obj1" + + inside(stringCast.argument.l) { case List(stringType: TypeRef, tmpIdentifier1: Identifier) => + stringType.typeFullName shouldBe "java.lang.String" + stringType.code shouldBe "String" + + tmpIdentifier1.name shouldBe "$obj1" + tmpIdentifier1.code shouldBe "$obj1" + tmpIdentifier1.typeFullName shouldBe "java.lang.Object" + tmpIdentifier1.refsTo.l shouldBe cpg.local.nameExact("$obj1").l + } + + iAssign.name shouldBe Operators.assignment + iAssign.methodFullName shouldBe Operators.assignment + iAssign.typeFullName shouldBe "java.lang.Integer" + iAssign.code shouldBe "i = (Integer) $obj2" + + inside(iAssign.argument.l) { case List(iIdentifier: Identifier, integerCast: Call) => + iIdentifier.name shouldBe "i" + iIdentifier.code shouldBe "i" + iIdentifier.typeFullName shouldBe "java.lang.Integer" + iIdentifier.refsTo.l shouldBe List(iLocal) + + integerCast.name shouldBe Operators.cast + integerCast.methodFullName shouldBe Operators.cast + integerCast.typeFullName shouldBe "java.lang.Integer" + integerCast.code shouldBe "(Integer) $obj2" + + inside(integerCast.argument.l) { case List(integerType: TypeRef, tmpIdentifier2: Identifier) => + integerType.typeFullName shouldBe "java.lang.Integer" + integerType.code shouldBe "Integer" + + tmpIdentifier2.name shouldBe "$obj2" + tmpIdentifier2.code shouldBe "$obj2" + tmpIdentifier2.typeFullName shouldBe "java.lang.Object" + tmpIdentifier2.refsTo.l shouldBe cpg.local.nameExact("$obj2").l + } + } + } + } + } + } + } + + "resolved patterns in switch expressions" when { + "a type pattern is matched" should { + val cpg = code(""" + |package box; + | + |class Foo { + | void foo(Object o) { + | switch (o) { + | case String s -> sink(s); + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").isEmpty shouldBe false + } + + "have the correct lowering for the variable assignment" in { + // TODO Should this be MATCH? + inside( + cpg.controlStructure.controlStructureType(ControlStructureTypes.SWITCH).astChildren.isBlock.astChildren.l + ) { case List(_: JumpTarget, instanceCheck: ControlStructure) => + inside(instanceCheck.astChildren.collectAll[Call].argument.l) { + case List(oIdentifier: Identifier, stringType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.code shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + + stringType.typeFullName shouldBe "java.lang.String" + } + instanceCheck.code shouldBe "if (o instanceof String)" + + inside(instanceCheck.astChildren.l) { case List(instanceOfCall: Call, statementsBlock: Block) => + instanceOfCall.code shouldBe "o instanceof String" + inside(statementsBlock.astChildren.l) { case List(sLocal: Local, sAssign: Call, sinkCall: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + sLocal.code shouldBe "String s" + + sAssign.name shouldBe Operators.assignment + sAssign.methodFullName shouldBe Operators.assignment + sAssign.code shouldBe "s = (String) o" + + inside(sAssign.argument.l) { case List(sIdentifier: Identifier, castExpr: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.code shouldBe "s" + sIdentifier.refsTo.l should contain theSameElementsAs List(sLocal) + + castExpr.name shouldBe Operators.cast + castExpr.methodFullName shouldBe Operators.cast + castExpr.typeFullName shouldBe "java.lang.String" + castExpr.code shouldBe "(String) o" + + inside(castExpr.argument.l) { case List(stringType: TypeRef, oIdentifier: Identifier) => + stringType.typeFullName shouldBe "java.lang.String" + stringType.code shouldBe "String" + + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.code shouldBe "o" + oIdentifier.refsTo.l should contain theSameElementsAs cpg.method.name("foo").parameter.name("o").l + } + } + } + } + } + } + } + + "a non-generic, non-nested record pattern is matched" should { + val cpg = code(""" + |package box; + | + |record Box(String value) {} + | + |class Foo { + | void foo(Object o) { + | switch (o) { + | case Box(String s) -> sink(s); + | default -> {} + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").isEmpty shouldBe false + } + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(instanceOfBox: Call) => + instanceOfBox.name shouldBe Operators.instanceOf + instanceOfBox.methodFullName shouldBe Operators.instanceOf + instanceOfBox.code shouldBe "o instanceof Box" + instanceOfBox.typeFullName shouldBe "boolean" + + inside(instanceOfBox.argument.l) { case List(oIdentifier: Identifier, boxType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.code shouldBe "o" + oIdentifier.refsTo.l shouldBe cpg.method.name("foo").parameter.name("o").l + + boxType.typeFullName shouldBe "box.Box" + boxType.code shouldBe "Box" + } + + } + } + + "have the correct lowering for the variable assignment" in { + inside( + cpg.controlStructure + .controlStructureType(ControlStructureTypes.IF) + .astChildren + .isBlock + .astChildren + .l + ) { case List(sLocal: Local, sAssignment: Call, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + sLocal.code shouldBe "String s" + + sAssignment.name shouldBe Operators.assignment + sAssignment.methodFullName shouldBe Operators.assignment + sAssignment.typeFullName shouldBe "java.lang.String" + sAssignment.code shouldBe "s = ((Box) o).value()" + + inside(sAssignment.argument.l) { case List(sIdentifier: Identifier, valueCall: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.code shouldBe "s" + sIdentifier.refsTo.l shouldBe List(sLocal) + + valueCall.name shouldBe "value" + valueCall.methodFullName shouldBe "box.Box.value:java.lang.String()" + valueCall.typeFullName shouldBe "java.lang.String" + valueCall.code shouldBe "((Box) o).value()" + + inside(valueCall.argument.l) { case List(oCast: Call) => + oCast.name shouldBe Operators.cast + oCast.code shouldBe "(Box) o" + oCast.typeFullName shouldBe "box.Box" + oCast.lineNumber shouldBe Some(9) + oCast.columnNumber shouldBe Some(12) + + inside(oCast.argument.l) { case List(boxType: TypeRef, oIdentifier: Identifier) => + boxType.typeFullName shouldBe "box.Box" + + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.refsTo.l shouldBe cpg.method.name("foo").parameter.name("o").l + } + } + } + } + } + } + + "a generic, non-nested record pattern is matched" should { + val cpg = code(""" + |package box; + | + |record Box(T value) {} + | + |class Foo { + | void foo(Object o) { + | switch (o) { + | case Box(String s) -> sink(s); + | default -> {} + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").isEmpty shouldBe false + } + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(andCall: Call) => + andCall.name shouldBe Operators.logicalAnd + andCall.methodFullName shouldBe Operators.logicalAnd + andCall.code shouldBe "(o instanceof Box) && (($obj0 = ((Box) o).value()) instanceof String)" + + inside(andCall.argument.l) { case List(instanceOfBox: Call, instanceOfString: Call) => + instanceOfBox.name shouldBe Operators.instanceOf + instanceOfBox.methodFullName shouldBe Operators.instanceOf + instanceOfBox.code shouldBe "o instanceof Box" + instanceOfBox.typeFullName shouldBe "boolean" + + inside(instanceOfBox.argument.l) { case List(oIdentifier: Identifier, boxType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.code shouldBe "o" + oIdentifier.refsTo.l shouldBe cpg.method.name("foo").parameter.name("o").l + + boxType.typeFullName shouldBe "box.Box" + boxType.code shouldBe "Box" + } + + instanceOfString.name shouldBe Operators.instanceOf + instanceOfString.methodFullName shouldBe Operators.instanceOf + instanceOfString.code shouldBe "($obj0 = ((Box) o).value()) instanceof String" + instanceOfString.typeFullName shouldBe "boolean" + + inside(instanceOfString.argument.l) { case List(tmpAssign: Call, stringType: TypeRef) => + tmpAssign.name shouldBe Operators.assignment + tmpAssign.methodFullName shouldBe Operators.assignment + tmpAssign.code shouldBe "$obj0 = ((Box) o).value()" + tmpAssign.typeFullName shouldBe "java.lang.Object" + + inside(tmpAssign.argument.l) { case List(tmpIdentifier0: Identifier, valueCall: Call) => + tmpIdentifier0.name shouldBe "$obj0" + tmpIdentifier0.code shouldBe "$obj0" + tmpIdentifier0.typeFullName shouldBe "java.lang.Object" + tmpIdentifier0.refsTo.l shouldBe cpg.local.nameExact("$obj0").l + + valueCall.name shouldBe "value" + valueCall.methodFullName shouldBe "box.Box.value:java.lang.Object()" + valueCall.code shouldBe "((Box) o).value()" + valueCall.typeFullName shouldBe "java.lang.Object" + + inside(valueCall.argument.l) { case List(castExpr: Call) => + castExpr.name shouldBe Operators.cast + castExpr.methodFullName shouldBe Operators.cast + castExpr.typeFullName shouldBe "box.Box" + inside(castExpr.argument.l) { case List(castBoxType: TypeRef, castOIdentifier: Identifier) => + castBoxType.typeFullName shouldBe "box.Box" + castBoxType.code shouldBe "Box" + + castOIdentifier.name shouldBe "o" + castOIdentifier.typeFullName shouldBe "java.lang.Object" + castOIdentifier.code shouldBe "o" + castOIdentifier.refsTo.l shouldBe cpg.method.name("foo").parameter.name("o").l + } + } + } + + stringType.typeFullName shouldBe "java.lang.String" + stringType.code shouldBe "String" + } + } + } + } + + "have the correct lowering for the variable assignment" in { + inside( + cpg.controlStructure + .controlStructureType(ControlStructureTypes.IF) + .astChildren + .isBlock + .astChildren + .l + ) { case List(sLocal: Local, sAssignment: Call, _: Call) => + sLocal.name shouldBe "s" + sLocal.typeFullName shouldBe "java.lang.String" + sLocal.code shouldBe "String s" + + sAssignment.name shouldBe Operators.assignment + sAssignment.methodFullName shouldBe Operators.assignment + sAssignment.typeFullName shouldBe "java.lang.String" + sAssignment.code shouldBe "s = (String) $obj0" + + inside(sAssignment.argument.l) { case List(sIdentifier: Identifier, stringCast: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.code shouldBe "s" + sIdentifier.refsTo.l shouldBe List(sLocal) + + stringCast.name shouldBe Operators.cast + stringCast.methodFullName shouldBe Operators.cast + stringCast.typeFullName shouldBe "java.lang.String" + stringCast.code shouldBe "(String) $obj0" + + inside(stringCast.argument.l) { case List(stringType: TypeRef, tmpIdentifier0: Identifier) => + stringType.typeFullName shouldBe "java.lang.String" + + tmpIdentifier0.name shouldBe "$obj0" + tmpIdentifier0.code shouldBe "$obj0" + tmpIdentifier0.typeFullName shouldBe "java.lang.Object" + tmpIdentifier0.refsTo.l shouldBe cpg.local.nameExact("$obj0").l + } + } + } + } + } + + "a non-generic, nested record pattern is matched" should { + + val cpg = code(""" + |package box; + | + |record PairBox(Pair value) {} + |record Pair(String first, Integer second) {} + | + |class Foo { + | void foo(Object o) { + | switch (o) { + | case PairBox(Pair(String s, Integer i)) -> { sink(s); sink(i); } + | default -> {} + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").isEmpty shouldBe false + } + + "have the correct lowering for the type check" in { + val oParameter = cpg.method.name("foo").parameter.name("o").l + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(oInstanceOfPairBox: Call) => + oInstanceOfPairBox.name shouldBe Operators.instanceOf + oInstanceOfPairBox.methodFullName shouldBe Operators.instanceOf + oInstanceOfPairBox.typeFullName shouldBe "boolean" + oInstanceOfPairBox.code shouldBe "o instanceof PairBox" + + inside(oInstanceOfPairBox.argument.l) { case List(oIdentifier: Identifier, pairBoxType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.code shouldBe "o" + oIdentifier.refsTo.l shouldBe oParameter + + pairBoxType.typeFullName shouldBe "box.PairBox" + pairBoxType.code shouldBe "PairBox" + } + } + } + + "have the correct lowering for the variable assignment" in { + val oParameter = cpg.method.name("foo").parameter.name("o").l + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List(sLocal: Local, sAssign: Call, iLocal: Local, iAssign: Call, sinkS: Call, sinkI: Call) => + sLocal.name shouldBe "s" + sLocal.code shouldBe "String s" + sLocal.typeFullName shouldBe "java.lang.String" + + iLocal.name shouldBe "i" + iLocal.code shouldBe "Integer i" + iLocal.typeFullName shouldBe "java.lang.Integer" + + sAssign.name shouldBe Operators.assignment + sAssign.methodFullName shouldBe Operators.assignment + sAssign.typeFullName shouldBe "java.lang.String" + sAssign.code shouldBe "s = ($obj0 = ((PairBox) o).value()).first()" + + inside(sAssign.argument.l) { case List(sIdentifier: Identifier, firstCall: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.code shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe List(sLocal) + + firstCall.name shouldBe "first" + firstCall.methodFullName shouldBe "box.Pair.first:java.lang.String()" + firstCall.code shouldBe "($obj0 = ((PairBox) o).value()).first()" + firstCall.typeFullName shouldBe "java.lang.String" + + inside(firstCall.argument.l) { case List(tmpAssign0: Call) => + tmpAssign0.name shouldBe Operators.assignment + tmpAssign0.code shouldBe "$obj0 = ((PairBox) o).value()" + tmpAssign0.typeFullName shouldBe "box.Pair" + + inside(tmpAssign0.argument.l) { case List(tmpIdentifier0: Identifier, valueCall: Call) => + tmpIdentifier0.name shouldBe "$obj0" + tmpIdentifier0.typeFullName shouldBe "box.Pair" + tmpIdentifier0.refsTo.l shouldBe cpg.local.nameExact("$obj0").l + + valueCall.name shouldBe "value" + valueCall.methodFullName shouldBe "box.PairBox.value:box.Pair()" + valueCall.code shouldBe "((PairBox) o).value()" + + inside(valueCall.argument.l) { case List(pairBoxCast: Call) => + pairBoxCast.name shouldBe Operators.cast + pairBoxCast.code shouldBe "(PairBox) o" + pairBoxCast.typeFullName shouldBe "box.PairBox" + + inside(pairBoxCast.argument.l) { case List(pairBoxType: TypeRef, oIdentifier: Identifier) => + pairBoxType.typeFullName shouldBe "box.PairBox" + + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.refsTo.l shouldBe cpg.parameter.name("o").l + } + } + + } + } + } + + iAssign.name shouldBe Operators.assignment + iAssign.methodFullName shouldBe Operators.assignment + iAssign.typeFullName shouldBe "java.lang.Integer" + iAssign.code shouldBe "i = $obj0.second()" + + inside(iAssign.argument.l) { case List(iIdentifier: Identifier, secondCall: Call) => + iIdentifier.name shouldBe "i" + iIdentifier.code shouldBe "i" + iIdentifier.typeFullName shouldBe "java.lang.Integer" + iIdentifier.refsTo.l shouldBe List(iLocal) + + secondCall.name shouldBe "second" + secondCall.methodFullName shouldBe "box.Pair.second:java.lang.Integer()" + secondCall.code shouldBe "$obj0.second()" + + inside(secondCall.argument.l) { case List(tmpIdentifier0: Identifier) => + tmpIdentifier0.name shouldBe "$obj0" + tmpIdentifier0.typeFullName shouldBe "box.Pair" + tmpIdentifier0.refsTo.l shouldBe cpg.local.nameExact("$obj0").l + } + } + + inside(sinkS.argument.isIdentifier.name("s").l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.code shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe List(sLocal) + } + + inside(sinkI.argument.isIdentifier.name("i").l) { case List(iIdentifier: Identifier) => + iIdentifier.name shouldBe "i" + iIdentifier.code shouldBe "i" + iIdentifier.typeFullName shouldBe "java.lang.Integer" + iIdentifier.refsTo.l shouldBe List(iLocal) + } + } + } + } + + "a generic, nested record pattern is matched" should { + + val cpg = code(""" + |package box; + | + |record Box(Pair value) {} + |record Pair(U first, V second) {} + | + |class Foo { + | void foo(Object o) { + | switch (o) { + | case Box(Pair(String s, Integer i)) -> { sink(s); sink(i); } + | default -> {} + | } + | } + |} + |""".stripMargin) + + "parse" in { + cpg.call.name("sink").isEmpty shouldBe false + } + + "have the correct lowering for the variable assignment" in { + val oParameter = cpg.method.name("foo").parameter.name("o").l + inside( + cpg.controlStructure + .controlStructureType(ControlStructureTypes.SWITCH) + .astChildren + .isBlock + .astChildren + .isControlStructure + .astChildren + .isBlock + .astChildren + .l + ) { case List(sLocal: Local, sAssign: Call, iLocal: Local, iAssign: Call, sSink: Call, iSink: Call) => + sLocal.name shouldBe "s" + sLocal.code shouldBe "String s" + sLocal.typeFullName shouldBe "java.lang.String" + + iLocal.name shouldBe "i" + iLocal.code shouldBe "Integer i" + iLocal.typeFullName shouldBe "java.lang.Integer" + + sAssign.name shouldBe Operators.assignment + sAssign.methodFullName shouldBe Operators.assignment + sAssign.typeFullName shouldBe "java.lang.String" + sAssign.code shouldBe "s = (String) $obj1" + + inside(sAssign.argument.l) { case List(sIdentifier: Identifier, stringCast: Call) => + sIdentifier.name shouldBe "s" + sIdentifier.code shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe List(sLocal) + + stringCast.name shouldBe Operators.cast + stringCast.methodFullName shouldBe Operators.cast + stringCast.typeFullName shouldBe "java.lang.String" + stringCast.code shouldBe "(String) $obj1" + + inside(stringCast.argument.l) { case List(stringType: TypeRef, tmpIdentifier1: Identifier) => + stringType.typeFullName shouldBe "java.lang.String" + stringType.code shouldBe "String" + + tmpIdentifier1.name shouldBe "$obj1" + tmpIdentifier1.code shouldBe "$obj1" + tmpIdentifier1.typeFullName shouldBe "java.lang.Object" + tmpIdentifier1.refsTo.l shouldBe cpg.local.nameExact("$obj1").l + } + } + + iAssign.name shouldBe Operators.assignment + iAssign.methodFullName shouldBe Operators.assignment + iAssign.typeFullName shouldBe "java.lang.Integer" + iAssign.code shouldBe "i = (Integer) $obj2" + + inside(iAssign.argument.l) { case List(iIdentifier: Identifier, integerCast: Call) => + iIdentifier.name shouldBe "i" + iIdentifier.code shouldBe "i" + iIdentifier.typeFullName shouldBe "java.lang.Integer" + iIdentifier.refsTo.l shouldBe List(iLocal) + + integerCast.name shouldBe Operators.cast + integerCast.methodFullName shouldBe Operators.cast + integerCast.typeFullName shouldBe "java.lang.Integer" + integerCast.code shouldBe "(Integer) $obj2" + + inside(integerCast.argument.l) { case List(integerType: TypeRef, tmpIdentifier2: Identifier) => + integerType.typeFullName shouldBe "java.lang.Integer" + integerType.code shouldBe "Integer" + + tmpIdentifier2.name shouldBe "$obj2" + tmpIdentifier2.code shouldBe "$obj2" + tmpIdentifier2.typeFullName shouldBe "java.lang.Object" + tmpIdentifier2.refsTo.l shouldBe cpg.local.nameExact("$obj2").l + } + } + + inside(sSink.argument.isIdentifier.name("s").l) { case List(sIdentifier: Identifier) => + sIdentifier.name shouldBe "s" + sIdentifier.code shouldBe "s" + sIdentifier.typeFullName shouldBe "java.lang.String" + sIdentifier.refsTo.l shouldBe List(sLocal) + } + + inside(iSink.argument.isIdentifier.name("i").l) { case List(iIdentifier: Identifier) => + iIdentifier.name shouldBe "i" + iIdentifier.code shouldBe "i" + iIdentifier.typeFullName shouldBe "java.lang.Integer" + iIdentifier.refsTo.l shouldBe List(iLocal) + } + } + } + } + } + + "unresolved patterns in instanceof expressions" when { + + "the pattern is a type pattern without an import fallback" should { + val cpg = code(""" + |class Foo { + | void foo(Object o) { + | if (o instanceof Bar b) { + | sink(b); + | } + | } + |} + |""".stripMargin) + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(oInstanceOfBar: Call) => + oInstanceOfBar.name shouldBe Operators.instanceOf + oInstanceOfBar.methodFullName shouldBe Operators.instanceOf + oInstanceOfBar.typeFullName shouldBe "boolean" + oInstanceOfBar.code shouldBe "o instanceof Bar" + + inside(oInstanceOfBar.argument.l) { case List(oIdentifier: Identifier, barType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.code shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.refsTo.l shouldBe cpg.method.name("foo").parameter.name("o").l + + barType.typeFullName shouldBe "ANY" + barType.code shouldBe "Bar" + } + } + } + + "have the correct lowering for the variable assignment" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List(bLocal: Local, bAssign: Call, bSink: Call) => + bLocal.name shouldBe "b" + bLocal.code shouldBe "Bar b" + bLocal.typeFullName shouldBe "ANY" + + bAssign.name shouldBe Operators.assignment + bAssign.methodFullName shouldBe Operators.assignment + bAssign.typeFullName shouldBe "ANY" + bAssign.code shouldBe "b = (Bar) o" + + inside(bAssign.argument.l) { case List(bIdentifier: Identifier, castCall: Call) => + bIdentifier.name shouldBe "b" + bIdentifier.code shouldBe "b" + bIdentifier.typeFullName shouldBe "ANY" + bIdentifier.refsTo.l shouldBe List(bLocal) + + castCall.name shouldBe Operators.cast + castCall.methodFullName shouldBe Operators.cast + castCall.typeFullName shouldBe "ANY" + castCall.code shouldBe "(Bar) o" + + inside(castCall.argument.l) { case List(barType: TypeRef, oIdentifier: Identifier) => + barType.code shouldBe "Bar" + barType.typeFullName shouldBe "ANY" + + oIdentifier.name shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + } + } + + bSink.argument.isIdentifier.name("b").refsTo.l shouldBe List(bLocal) + } + } + } + + "the pattern is a type pattern with an import fallback" should { + val cpg = code(""" + |import bar.Bar; + | + |class Foo { + | void foo(Object o) { + | if (o instanceof Bar b) { + | sink(b); + | } + | } + |} + |""".stripMargin) + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(oInstanceOfBar: Call) => + oInstanceOfBar.name shouldBe Operators.instanceOf + oInstanceOfBar.methodFullName shouldBe Operators.instanceOf + oInstanceOfBar.typeFullName shouldBe "boolean" + oInstanceOfBar.code shouldBe "o instanceof Bar" + + inside(oInstanceOfBar.argument.l) { case List(oIdentifier: Identifier, barType: TypeRef) => + oIdentifier.name shouldBe "o" + oIdentifier.code shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.refsTo.l shouldBe cpg.method.name("foo").parameter.name("o").l + + barType.typeFullName shouldBe "bar.Bar" + barType.code shouldBe "Bar" + } + } + } + + "have the correct lowering for the variable assignment" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List(bLocal: Local, bAssign: Call, bSink: Call) => + bLocal.name shouldBe "b" + bLocal.code shouldBe "Bar b" + bLocal.typeFullName shouldBe "bar.Bar" + + bAssign.name shouldBe Operators.assignment + bAssign.methodFullName shouldBe Operators.assignment + bAssign.typeFullName shouldBe "bar.Bar" + bAssign.code shouldBe "b = (Bar) o" + + inside(bAssign.argument.l) { case List(bIdentifier: Identifier, castCall: Call) => + bIdentifier.name shouldBe "b" + bIdentifier.code shouldBe "b" + bIdentifier.typeFullName shouldBe "bar.Bar" + bIdentifier.refsTo.l shouldBe List(bLocal) + + castCall.name shouldBe Operators.cast + castCall.methodFullName shouldBe Operators.cast + castCall.typeFullName shouldBe "bar.Bar" + castCall.code shouldBe "(Bar) o" + + inside(castCall.argument.l) { case List(barType: TypeRef, oIdentifier: Identifier) => + barType.typeFullName shouldBe "bar.Bar" + + oIdentifier.name shouldBe "o" + oIdentifier.code shouldBe "o" + oIdentifier.typeFullName shouldBe "java.lang.Object" + oIdentifier.refsTo.l shouldBe cpg.method.name("foo").parameter.name("o").l + } + } + + bSink.argument.isIdentifier.name("b").refsTo.l shouldBe List(bLocal) + } + } + } + + "the pattern is a nested record pattern" should { + val cpg = code(""" + |class Foo { + | void foo(Object o) { + | if (o instanceof Bar(Baz(Qux q))) { + | sink(q); + | } + | } + |} + |""".stripMargin) + + "have the correct lowering for the type check" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.l) { + case List(firstAnd: Call) => + firstAnd.name shouldBe Operators.logicalAnd + firstAnd.methodFullName shouldBe Operators.logicalAnd + firstAnd.typeFullName shouldBe "boolean" + firstAnd.code shouldBe "(o instanceof Bar) && ((($obj0 = ((Bar) o).()) instanceof Baz) && (($obj1 = ((Baz) $obj0).()) instanceof Qux))" + + inside(firstAnd.argument.l) { case List(instanceOfBar: Call, secondAnd: Call) => + instanceOfBar.name shouldBe Operators.instanceOf + instanceOfBar.methodFullName shouldBe Operators.instanceOf + instanceOfBar.code shouldBe "o instanceof Bar" + instanceOfBar.typeFullName shouldBe "boolean" + + inside(secondAnd.argument.l) { case List(instanceOfBaz: Call, instanceOfQux: Call) => + instanceOfBaz.name shouldBe Operators.instanceOf + instanceOfBaz.methodFullName shouldBe Operators.instanceOf + instanceOfBaz.code shouldBe "($obj0 = ((Bar) o).()) instanceof Baz" + instanceOfBaz.typeFullName shouldBe "boolean" + + inside(instanceOfBaz.argument.l) { case List(tmpAssign: Call, bazType: TypeRef) => + inside(tmpAssign.argument.l) { case List(tmpIdentifier0: Identifier, fieldAccessor: Call) => + tmpIdentifier0.name shouldBe "$obj0" + tmpIdentifier0.code shouldBe "$obj0" + tmpIdentifier0.typeFullName shouldBe "ANY" + tmpIdentifier0.refsTo.l shouldBe cpg.local.nameExact("$obj0").l + + fieldAccessor.name shouldBe "" + fieldAccessor.methodFullName shouldBe ".Bar.:(0)" + fieldAccessor.typeFullName shouldBe "ANY" + fieldAccessor.code shouldBe "((Bar) o).()" + } + + bazType.typeFullName shouldBe "ANY" + bazType.code shouldBe "Baz" + } + } + } + } + } + + "have the correct lowering for the variable assignment" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).astChildren.isBlock.astChildren.l) { + case List(qLocal: Local, qAssign: Call, qSink: Call) => + qLocal.name shouldBe "q" + qLocal.code shouldBe "Qux q" + qLocal.typeFullName shouldBe "ANY" + + qAssign.name shouldBe Operators.assignment + qAssign.methodFullName shouldBe Operators.assignment + qAssign.typeFullName shouldBe "ANY" + qAssign.code shouldBe "q = (Qux) $obj1" + + inside(qAssign.argument.l) { case List(qIdentifier: Identifier, quxCast: Call) => + qIdentifier.name shouldBe "q" + qIdentifier.code shouldBe "q" + qIdentifier.typeFullName shouldBe "ANY" + qIdentifier.refsTo.l shouldBe List(qLocal) + + quxCast.name shouldBe Operators.cast + + inside(quxCast.argument.l) { case List(quxType: TypeRef, tmpIdentifier1: Identifier) => + quxType.code shouldBe "Qux" + + tmpIdentifier1.name shouldBe "$obj1" + tmpIdentifier1.code shouldBe "$obj1" + tmpIdentifier1.typeFullName shouldBe "ANY" + tmpIdentifier1.refsTo.l shouldBe cpg.local.nameExact("$obj1").l + } + } + } + } + } + } + +} diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/RecordTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/RecordTests.scala new file mode 100644 index 000000000000..a3d8d0166a5b --- /dev/null +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/RecordTests.scala @@ -0,0 +1,609 @@ +package io.joern.javasrc2cpg.querying + +import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.{ModifierTypes, Operators} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier, Literal, Method, Return} +import io.shiftleft.semanticcpg.language.* + +class RecordTests extends JavaSrcCode2CpgFixture { + + "a record with a compact constructor" should { + val cpg = code(""" + |package foo; + | + |record Foo(String value) { + | public Foo { + | System.out.println(value); + | } + |} + |""".stripMargin) + + "extend java.lang.Record" in { + cpg.typeDecl("Foo").inheritsFromTypeFullName.l shouldBe List("java.lang.Record") + } + + "have the correct representation for the compact constructor" in { + inside(cpg.method.nameExact("").l) { case List(constructor) => + constructor.fullName shouldBe "foo.Foo.:void(java.lang.String)" + + inside(constructor.parameter.l) { case List(thisParam, valueParam) => + thisParam.name shouldBe "this" + thisParam.typeFullName shouldBe "foo.Foo" + + valueParam.name shouldBe "value" + valueParam.typeFullName shouldBe "java.lang.String" + + inside(constructor.body.astChildren.l) { case List(valueAssign: Call, printlnCall: Call) => + valueAssign.name shouldBe Operators.assignment + valueAssign.methodFullName shouldBe Operators.assignment + valueAssign.typeFullName shouldBe "java.lang.String" + valueAssign.code shouldBe "this.value = value" + + inside(valueAssign.argument.l) { case List(fieldAccess: Call, valueIdentifier: Identifier) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.value" + fieldAccess.typeFullName shouldBe "java.lang.String" + + inside(fieldAccess.argument.l) { + case List(thisIdentifier: Identifier, valueFieldIdentifier: FieldIdentifier) => + thisIdentifier.name shouldBe "this" + thisIdentifier.typeFullName shouldBe "foo.Foo" + thisIdentifier.refsTo.l shouldBe List(thisParam) + + valueFieldIdentifier.canonicalName shouldBe "value" + } + + valueIdentifier.name shouldBe "value" + valueIdentifier.typeFullName shouldBe "java.lang.String" + valueIdentifier.refsTo.l shouldBe List(valueParam) + } + + printlnCall.name shouldBe "println" + printlnCall.code shouldBe "System.out.println(value)" + inside(printlnCall.argument.l) { case List(_, valueIdentifier: Identifier) => + valueIdentifier.name shouldBe "value" + valueIdentifier.refsTo.l shouldBe List(valueParam) + } + } + } + } + } + + "have a private field for the parameter" in { + inside(cpg.member.l) { case List(valueMember) => + valueMember.name shouldBe "value" + valueMember.code shouldBe "String value" + valueMember.typeFullName shouldBe "java.lang.String" + valueMember.modifier.modifierType.l shouldBe List(ModifierTypes.PRIVATE) + } + } + + "have a public accessor method for the parameter" in { + inside(cpg.method.name("value").l) { case List(valueMethod: Method) => + valueMethod.name shouldBe "value" + valueMethod.fullName shouldBe "foo.Foo.value:java.lang.String()" + valueMethod.code shouldBe "public String value()" + valueMethod.lineNumber shouldBe Some(4) + valueMethod.columnNumber shouldBe Some(12) + + val methodReturn = valueMethod.methodReturn + methodReturn.typeFullName shouldBe "java.lang.String" + methodReturn.lineNumber shouldBe Some(4) + methodReturn.columnNumber shouldBe Some(12) + + inside(valueMethod.parameter.l) { case List(thisParam) => + thisParam.name shouldBe "this" + thisParam.typeFullName shouldBe "foo.Foo" + thisParam.lineNumber shouldBe Some(4) + thisParam.columnNumber shouldBe Some(12) + } + + inside(valueMethod.body.astChildren.l) { case List(returnStmt: Return) => + returnStmt.code shouldBe "return this.value" + returnStmt.lineNumber shouldBe Some(4) + returnStmt.columnNumber shouldBe Some(12) + + inside(returnStmt.astChildren.l) { case List(fieldAccess: Call) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.methodFullName shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.value" + fieldAccess.typeFullName shouldBe "java.lang.String" + fieldAccess.lineNumber shouldBe Some(4) + fieldAccess.columnNumber shouldBe Some(12) + + inside(fieldAccess.argument.l) { case List(thisIdentifier: Identifier, fieldIdentifier: FieldIdentifier) => + thisIdentifier.name shouldBe "this" + thisIdentifier.code shouldBe "this" + thisIdentifier.typeFullName shouldBe "foo.Foo" + thisIdentifier.refsTo.l shouldBe cpg.method.name("value").parameter.l + thisIdentifier.lineNumber shouldBe Some(4) + thisIdentifier.columnNumber shouldBe Some(12) + + fieldIdentifier.canonicalName shouldBe "value" + fieldIdentifier.code shouldBe "value" + fieldIdentifier.lineNumber shouldBe Some(4) + fieldIdentifier.columnNumber shouldBe Some(12) + } + } + } + } + } + } + + "a record with an explicit non-canonical constructor" should { + val cpg = code(""" + |package foo; + | + |record Foo(String value) { + | public Foo() { + | this.value = "value"; + | } + |} + |""".stripMargin) + + "have the correct constructors" in { + inside(cpg.method.nameExact("").sortBy(_.parameter.size).l) { + case List(explicitConstructor, canonicalConstructor) => + explicitConstructor.fullName shouldBe "foo.Foo.:void()" + + inside(explicitConstructor.parameter.l) { case List(thisParam) => + thisParam.name shouldBe "this" + thisParam.typeFullName shouldBe "foo.Foo" + + inside(explicitConstructor.body.astChildren.l) { case List(valueAssign: Call) => + valueAssign.name shouldBe Operators.assignment + valueAssign.methodFullName shouldBe Operators.assignment + valueAssign.typeFullName shouldBe "java.lang.String" + valueAssign.code shouldBe "this.value = \"value\"" + + inside(valueAssign.argument.l) { case List(fieldAccess: Call, valueLiteral: Literal) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.value" + fieldAccess.typeFullName shouldBe "java.lang.String" + + inside(fieldAccess.argument.l) { + case List(thisIdentifier: Identifier, valueFieldIdentifier: FieldIdentifier) => + thisIdentifier.name shouldBe "this" + thisIdentifier.typeFullName shouldBe "foo.Foo" + thisIdentifier.refsTo.l shouldBe List(thisParam) + + valueFieldIdentifier.canonicalName shouldBe "value" + } + + valueLiteral.typeFullName shouldBe "java.lang.String" + valueLiteral.code shouldBe "\"value\"" + } + } + } + + canonicalConstructor.fullName shouldBe "foo.Foo.:void(java.lang.String)" + + inside(canonicalConstructor.parameter.l) { case List(thisParam, valueParam) => + thisParam.name shouldBe "this" + thisParam.typeFullName shouldBe "foo.Foo" + + valueParam.name shouldBe "value" + valueParam.typeFullName shouldBe "java.lang.String" + + inside(canonicalConstructor.body.astChildren.l) { case List(valueAssign: Call) => + valueAssign.name shouldBe Operators.assignment + valueAssign.methodFullName shouldBe Operators.assignment + valueAssign.typeFullName shouldBe "java.lang.String" + valueAssign.code shouldBe "this.value = value" + + inside(valueAssign.argument.l) { case List(fieldAccess: Call, valueIdentifier: Identifier) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.value" + fieldAccess.typeFullName shouldBe "java.lang.String" + + inside(fieldAccess.argument.l) { + case List(thisIdentifier: Identifier, valueFieldIdentifier: FieldIdentifier) => + thisIdentifier.name shouldBe "this" + thisIdentifier.typeFullName shouldBe "foo.Foo" + thisIdentifier.refsTo.l shouldBe List(thisParam) + + valueFieldIdentifier.canonicalName shouldBe "value" + } + + valueIdentifier.name shouldBe "value" + valueIdentifier.typeFullName shouldBe "java.lang.String" + valueIdentifier.refsTo.l shouldBe List(valueParam) + } + } + } + } + } + + "have a private field for the parameter" in { + inside(cpg.member.l) { case List(valueMember) => + valueMember.name shouldBe "value" + valueMember.code shouldBe "String value" + valueMember.typeFullName shouldBe "java.lang.String" + valueMember.modifier.modifierType.l shouldBe List(ModifierTypes.PRIVATE) + } + } + + "have a public accessor method for the parameter" in { + inside(cpg.method.name("value").l) { case List(valueMethod: Method) => + valueMethod.name shouldBe "value" + valueMethod.fullName shouldBe "foo.Foo.value:java.lang.String()" + valueMethod.code shouldBe "public String value()" + valueMethod.lineNumber shouldBe Some(4) + valueMethod.columnNumber shouldBe Some(12) + + val methodReturn = valueMethod.methodReturn + methodReturn.typeFullName shouldBe "java.lang.String" + methodReturn.lineNumber shouldBe Some(4) + methodReturn.columnNumber shouldBe Some(12) + + inside(valueMethod.parameter.l) { case List(thisParam) => + thisParam.name shouldBe "this" + thisParam.typeFullName shouldBe "foo.Foo" + thisParam.lineNumber shouldBe Some(4) + thisParam.columnNumber shouldBe Some(12) + } + + inside(valueMethod.body.astChildren.l) { case List(returnStmt: Return) => + returnStmt.code shouldBe "return this.value" + returnStmt.lineNumber shouldBe Some(4) + returnStmt.columnNumber shouldBe Some(12) + + inside(returnStmt.astChildren.l) { case List(fieldAccess: Call) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.methodFullName shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.value" + fieldAccess.typeFullName shouldBe "java.lang.String" + fieldAccess.lineNumber shouldBe Some(4) + fieldAccess.columnNumber shouldBe Some(12) + + inside(fieldAccess.argument.l) { case List(thisIdentifier: Identifier, fieldIdentifier: FieldIdentifier) => + thisIdentifier.name shouldBe "this" + thisIdentifier.code shouldBe "this" + thisIdentifier.typeFullName shouldBe "foo.Foo" + thisIdentifier.refsTo.l shouldBe cpg.method.name("value").parameter.l + thisIdentifier.lineNumber shouldBe Some(4) + thisIdentifier.columnNumber shouldBe Some(12) + + fieldIdentifier.canonicalName shouldBe "value" + fieldIdentifier.code shouldBe "value" + fieldIdentifier.lineNumber shouldBe Some(4) + fieldIdentifier.columnNumber shouldBe Some(12) + } + } + } + } + } + } + + "a record with an explicit canonical constructor" should { + val cpg = code(""" + |package foo; + | + |record Foo(String value) { + | public Foo(String value) { + | System.out.println(value); + | this.value = value; + | } + |} + |""".stripMargin) + + "have the correct constructor" in { + inside(cpg.method.nameExact("").l) { case List(constructor) => + constructor.fullName shouldBe "foo.Foo.:void(java.lang.String)" + + inside(constructor.parameter.l) { case List(thisParam, valueParam) => + thisParam.name shouldBe "this" + thisParam.typeFullName shouldBe "foo.Foo" + + valueParam.name shouldBe "value" + valueParam.typeFullName shouldBe "java.lang.String" + + inside(constructor.body.astChildren.l) { case List(printlnCall: Call, valueAssign: Call) => + printlnCall.name shouldBe "println" + printlnCall.code shouldBe "System.out.println(value)" + + valueAssign.name shouldBe Operators.assignment + valueAssign.methodFullName shouldBe Operators.assignment + valueAssign.typeFullName shouldBe "java.lang.String" + valueAssign.code shouldBe "this.value = value" + + inside(valueAssign.argument.l) { case List(fieldAccess: Call, valueIdentifier: Identifier) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.value" + fieldAccess.typeFullName shouldBe "java.lang.String" + + inside(fieldAccess.argument.l) { + case List(thisIdentifier: Identifier, valueFieldIdentifier: FieldIdentifier) => + thisIdentifier.name shouldBe "this" + thisIdentifier.typeFullName shouldBe "foo.Foo" + thisIdentifier.refsTo.l shouldBe List(thisParam) + + valueFieldIdentifier.canonicalName shouldBe "value" + } + + valueIdentifier.name shouldBe "value" + valueIdentifier.typeFullName shouldBe "java.lang.String" + valueIdentifier.refsTo.l shouldBe List(valueParam) + } + } + } + } + } + + "have a private field for the parameter" in { + inside(cpg.member.l) { case List(valueMember) => + valueMember.name shouldBe "value" + valueMember.code shouldBe "String value" + valueMember.typeFullName shouldBe "java.lang.String" + valueMember.modifier.modifierType.l shouldBe List(ModifierTypes.PRIVATE) + } + } + + "have a public accessor method for the parameter" in { + inside(cpg.method.name("value").l) { case List(valueMethod: Method) => + valueMethod.name shouldBe "value" + valueMethod.fullName shouldBe "foo.Foo.value:java.lang.String()" + valueMethod.code shouldBe "public String value()" + valueMethod.lineNumber shouldBe Some(4) + valueMethod.columnNumber shouldBe Some(12) + + val methodReturn = valueMethod.methodReturn + methodReturn.typeFullName shouldBe "java.lang.String" + methodReturn.lineNumber shouldBe Some(4) + methodReturn.columnNumber shouldBe Some(12) + + inside(valueMethod.parameter.l) { case List(thisParam) => + thisParam.name shouldBe "this" + thisParam.typeFullName shouldBe "foo.Foo" + thisParam.lineNumber shouldBe Some(4) + thisParam.columnNumber shouldBe Some(12) + } + + inside(valueMethod.body.astChildren.l) { case List(returnStmt: Return) => + returnStmt.code shouldBe "return this.value" + returnStmt.lineNumber shouldBe Some(4) + returnStmt.columnNumber shouldBe Some(12) + + inside(returnStmt.astChildren.l) { case List(fieldAccess: Call) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.methodFullName shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.value" + fieldAccess.typeFullName shouldBe "java.lang.String" + fieldAccess.lineNumber shouldBe Some(4) + fieldAccess.columnNumber shouldBe Some(12) + + inside(fieldAccess.argument.l) { case List(thisIdentifier: Identifier, fieldIdentifier: FieldIdentifier) => + thisIdentifier.name shouldBe "this" + thisIdentifier.code shouldBe "this" + thisIdentifier.typeFullName shouldBe "foo.Foo" + thisIdentifier.refsTo.l shouldBe cpg.method.name("value").parameter.l + thisIdentifier.lineNumber shouldBe Some(4) + thisIdentifier.columnNumber shouldBe Some(12) + + fieldIdentifier.canonicalName shouldBe "value" + fieldIdentifier.code shouldBe "value" + fieldIdentifier.lineNumber shouldBe Some(4) + fieldIdentifier.columnNumber shouldBe Some(12) + } + } + } + } + } + } + + "a record with a generic parameter" should { + val cpg = code(""" + |package foo; + | + |record Foo(T value) {} + |""".stripMargin) + + "have the correct default canonical constructor" in { + inside(cpg.method.nameExact("").l) { case List(constructor) => + constructor.fullName shouldBe "foo.Foo.:void(java.lang.Object)" + + inside(constructor.parameter.l) { case List(thisParam, valueParam) => + thisParam.name shouldBe "this" + thisParam.typeFullName shouldBe "foo.Foo" + + valueParam.name shouldBe "value" + valueParam.typeFullName shouldBe "java.lang.Object" + + inside(constructor.body.astChildren.l) { case List(valueAssign: Call) => + valueAssign.name shouldBe Operators.assignment + valueAssign.methodFullName shouldBe Operators.assignment + valueAssign.typeFullName shouldBe "java.lang.Object" + valueAssign.code shouldBe "this.value = value" + + inside(valueAssign.argument.l) { case List(fieldAccess: Call, valueIdentifier: Identifier) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.value" + fieldAccess.typeFullName shouldBe "java.lang.Object" + + inside(fieldAccess.argument.l) { + case List(thisIdentifier: Identifier, valueFieldIdentifier: FieldIdentifier) => + thisIdentifier.name shouldBe "this" + thisIdentifier.typeFullName shouldBe "foo.Foo" + thisIdentifier.refsTo.l shouldBe List(thisParam) + + valueFieldIdentifier.canonicalName shouldBe "value" + } + + valueIdentifier.name shouldBe "value" + valueIdentifier.typeFullName shouldBe "java.lang.Object" + valueIdentifier.refsTo.l shouldBe List(valueParam) + } + } + } + } + } + + "have a private field for the parameter" in { + inside(cpg.member.l) { case List(valueMember) => + valueMember.name shouldBe "value" + valueMember.code shouldBe "T value" + valueMember.typeFullName shouldBe "java.lang.Object" + valueMember.modifier.modifierType.l shouldBe List(ModifierTypes.PRIVATE) + } + } + + "have a public accessor method for the parameter" in { + inside(cpg.method.name("value").l) { case List(valueMethod: Method) => + valueMethod.name shouldBe "value" + valueMethod.fullName shouldBe "foo.Foo.value:java.lang.Object()" + valueMethod.code shouldBe "public T value()" + valueMethod.lineNumber shouldBe Some(4) + valueMethod.columnNumber shouldBe Some(15) + + val methodReturn = valueMethod.methodReturn + methodReturn.typeFullName shouldBe "java.lang.Object" + methodReturn.lineNumber shouldBe Some(4) + methodReturn.columnNumber shouldBe Some(15) + + inside(valueMethod.parameter.l) { case List(thisParam) => + thisParam.name shouldBe "this" + thisParam.typeFullName shouldBe "foo.Foo" + thisParam.lineNumber shouldBe Some(4) + thisParam.columnNumber shouldBe Some(15) + } + + inside(valueMethod.body.astChildren.l) { case List(returnStmt: Return) => + returnStmt.code shouldBe "return this.value" + returnStmt.lineNumber shouldBe Some(4) + returnStmt.columnNumber shouldBe Some(15) + + inside(returnStmt.astChildren.l) { case List(fieldAccess: Call) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.methodFullName shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.value" + fieldAccess.typeFullName shouldBe "java.lang.Object" + fieldAccess.lineNumber shouldBe Some(4) + fieldAccess.columnNumber shouldBe Some(15) + + inside(fieldAccess.argument.l) { case List(thisIdentifier: Identifier, fieldIdentifier: FieldIdentifier) => + thisIdentifier.name shouldBe "this" + thisIdentifier.code shouldBe "this" + thisIdentifier.typeFullName shouldBe "foo.Foo" + thisIdentifier.refsTo.l shouldBe cpg.method.name("value").parameter.l + thisIdentifier.lineNumber shouldBe Some(4) + thisIdentifier.columnNumber shouldBe Some(15) + + fieldIdentifier.canonicalName shouldBe "value" + fieldIdentifier.code shouldBe "value" + fieldIdentifier.lineNumber shouldBe Some(4) + fieldIdentifier.columnNumber shouldBe Some(15) + } + } + } + } + } + } + + "a simple record with no explicit body" should { + val cpg = code(""" + |package foo; + | + |record Foo(String value) {} + |""".stripMargin) + + "have the correct default canonical constructor" in { + inside(cpg.method.nameExact("").l) { case List(constructor) => + constructor.fullName shouldBe "foo.Foo.:void(java.lang.String)" + + inside(constructor.parameter.l) { case List(thisParam, valueParam) => + thisParam.name shouldBe "this" + thisParam.typeFullName shouldBe "foo.Foo" + + valueParam.name shouldBe "value" + valueParam.typeFullName shouldBe "java.lang.String" + + inside(constructor.body.astChildren.l) { case List(valueAssign: Call) => + valueAssign.name shouldBe Operators.assignment + valueAssign.methodFullName shouldBe Operators.assignment + valueAssign.typeFullName shouldBe "java.lang.String" + valueAssign.code shouldBe "this.value = value" + + inside(valueAssign.argument.l) { case List(fieldAccess: Call, valueIdentifier: Identifier) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.value" + fieldAccess.typeFullName shouldBe "java.lang.String" + + inside(fieldAccess.argument.l) { + case List(thisIdentifier: Identifier, valueFieldIdentifier: FieldIdentifier) => + thisIdentifier.name shouldBe "this" + thisIdentifier.typeFullName shouldBe "foo.Foo" + thisIdentifier.refsTo.l shouldBe List(thisParam) + + valueFieldIdentifier.canonicalName shouldBe "value" + } + + valueIdentifier.name shouldBe "value" + valueIdentifier.typeFullName shouldBe "java.lang.String" + valueIdentifier.refsTo.l shouldBe List(valueParam) + } + } + } + } + } + + "have a private field for the parameter" in { + inside(cpg.member.l) { case List(valueMember) => + valueMember.name shouldBe "value" + valueMember.code shouldBe "String value" + valueMember.typeFullName shouldBe "java.lang.String" + valueMember.modifier.modifierType.l shouldBe List(ModifierTypes.PRIVATE) + } + } + + "have a public accessor method for the parameter" in { + inside(cpg.method.name("value").l) { case List(valueMethod: Method) => + valueMethod.name shouldBe "value" + valueMethod.fullName shouldBe "foo.Foo.value:java.lang.String()" + valueMethod.code shouldBe "public String value()" + valueMethod.lineNumber shouldBe Some(4) + valueMethod.columnNumber shouldBe Some(12) + + val methodReturn = valueMethod.methodReturn + methodReturn.typeFullName shouldBe "java.lang.String" + methodReturn.lineNumber shouldBe Some(4) + methodReturn.columnNumber shouldBe Some(12) + + inside(valueMethod.parameter.l) { case List(thisParam) => + thisParam.name shouldBe "this" + thisParam.typeFullName shouldBe "foo.Foo" + thisParam.lineNumber shouldBe Some(4) + thisParam.columnNumber shouldBe Some(12) + } + + inside(valueMethod.body.astChildren.l) { case List(returnStmt: Return) => + returnStmt.code shouldBe "return this.value" + returnStmt.lineNumber shouldBe Some(4) + returnStmt.columnNumber shouldBe Some(12) + + inside(returnStmt.astChildren.l) { case List(fieldAccess: Call) => + fieldAccess.name shouldBe Operators.fieldAccess + fieldAccess.methodFullName shouldBe Operators.fieldAccess + fieldAccess.code shouldBe "this.value" + fieldAccess.typeFullName shouldBe "java.lang.String" + fieldAccess.lineNumber shouldBe Some(4) + fieldAccess.columnNumber shouldBe Some(12) + + inside(fieldAccess.argument.l) { case List(thisIdentifier: Identifier, fieldIdentifier: FieldIdentifier) => + thisIdentifier.name shouldBe "this" + thisIdentifier.code shouldBe "this" + thisIdentifier.typeFullName shouldBe "foo.Foo" + thisIdentifier.refsTo.l shouldBe cpg.method.name("value").parameter.l + thisIdentifier.lineNumber shouldBe Some(4) + thisIdentifier.columnNumber shouldBe Some(12) + + fieldIdentifier.canonicalName shouldBe "value" + fieldIdentifier.code shouldBe "value" + fieldIdentifier.lineNumber shouldBe Some(4) + fieldIdentifier.columnNumber shouldBe Some(12) + } + } + } + } + } + } +} diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ScopeTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ScopeTests.scala index 34944d78812c..ebaf3bd36d86 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ScopeTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ScopeTests.scala @@ -2,8 +2,8 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier, TypeRef} +import io.shiftleft.semanticcpg.language.* class ScopeTests extends JavaSrcCode2CpgFixture { @@ -151,9 +151,8 @@ class ScopeTests extends JavaSrcCode2CpgFixture { fieldAccess.methodFullName shouldBe Operators.fieldAccess fieldAccess.typeFullName shouldBe "java.lang.Object" fieldAccess.argument.l match { - case List(identifier: Identifier, fieldIdentifier: FieldIdentifier) => - identifier.name shouldBe "Test" - identifier.typeFullName shouldBe "Test" + case List(typeRef: TypeRef, fieldIdentifier: FieldIdentifier) => + typeRef.typeFullName shouldBe "Test" fieldIdentifier.canonicalName shouldBe "staticO" case res => fail(s"Expected field access args but got $res") } diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/SpecialOperatorTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/SpecialOperatorTests.scala index 9dbf0dfc7380..99a6dd5520d0 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/SpecialOperatorTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/SpecialOperatorTests.scala @@ -3,7 +3,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, TypeRef} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class SpecialOperatorTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/SynchronizedTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/SynchronizedTests.scala index 66eeb14c0d67..48b215d300c8 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/SynchronizedTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/SynchronizedTests.scala @@ -11,7 +11,7 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ Modifier, Return } -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class SynchronizedTests extends JavaSrcCode2CpgFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeDeclTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeDeclTests.scala index 3d4857c67aca..8e287021afb9 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeDeclTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeDeclTests.scala @@ -3,7 +3,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.ModifierTypes import io.shiftleft.codepropertygraph.generated.nodes.Return -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import java.io.File diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeFallbackTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeFallbackTests.scala new file mode 100644 index 000000000000..ac2396308fd5 --- /dev/null +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeFallbackTests.scala @@ -0,0 +1,98 @@ +package io.joern.javasrc2cpg.queryinfieldsg + +import io.joern.javasrc2cpg.Config +import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier} +import io.shiftleft.semanticcpg.language.* + +class TypeFallbackTests extends JavaSrcCode2CpgFixture { + private val typeFallbackDisabledConfig = Config().withDisableTypeFallback(true) + + "cpgs generated with type fallbacks disabled" should { + + "set the type of unresolved locals to .Type" in { + val cpg = code(""" + |class Foo { + | void foo() { + | Bar b = new Bar(); + | } + |} + |""".stripMargin) + .withConfig(typeFallbackDisabledConfig) + + cpg.call(".*alloc.*").typeFullName.l shouldBe List(".Bar") + cpg.method("foo").local.name("b").typeFullName.l shouldBe List(".Bar") + } + + "set the type of call receivers to .Type" in { + val cpg = code(""" + |class Foo { + | void foo(Bar b) { + | b.bar(); + | } + |} + |""".stripMargin) + .withConfig(typeFallbackDisabledConfig) + + inside(cpg.call.name("bar").l) { case List(barCall: Call) => + barCall.methodFullName shouldBe ".Bar.bar:(0)" + + inside(barCall.receiver.l) { case List(bIdentifier: Identifier) => + bIdentifier.name shouldBe "b" + bIdentifier.typeFullName shouldBe ".Bar" + } + } + } + + "set the type of unresolved parameters to .Type" in { + val cpg = code(""" + |class Foo { + | void foo(Bar b) { } + |} + |""".stripMargin) + .withConfig(typeFallbackDisabledConfig) + + cpg.method("foo").parameter.name("b").typeFullName.l shouldBe List(".Bar") + } + + "set the type of unresolved fields to .Type" in { + val cpg = code(""" + |class Foo { + | Bar b; + |} + |""".stripMargin) + .withConfig(typeFallbackDisabledConfig) + + cpg.member("b").typeFullName.l shouldBe List(".Bar") + } + + "set the type of unresolved pattern variables to .Type" in { + val cpg = code(""" + |class Foo { + | void foo(Object o) { + | if (o instanceof Bar b) {} + | } + |} + |""".stripMargin) + .withConfig(typeFallbackDisabledConfig) + + cpg.method("foo").local.name("b").typeFullName.l shouldBe List(".Bar") + } + + "not use wildcard imports as a fallback" in { + val cpg = code(""" + |import testpackage.*; + | + |class Foo { + | void foo() { + | Bar b = new Bar(); + | } + |} + |""".stripMargin) + .withConfig(typeFallbackDisabledConfig) + + cpg.call(".*alloc.*").typeFullName.l shouldBe List(".Bar") + cpg.method("foo").local.name("b").typeFullName.l shouldBe List(".Bar") + } + } +} diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala index 5f2562cbd982..ce3a116dd730 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala @@ -4,7 +4,7 @@ import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.io.File diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeTests.scala index 7ef189482af9..1e37c6ed4139 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeTests.scala @@ -5,7 +5,7 @@ import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier} import io.shiftleft.proto.cpg.Cpg.DispatchTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class NewTypeTests extends JavaSrcCode2CpgFixture { "processing wildcard types should not crash (smoke test)" when { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ArrayTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ArrayTests.scala index 4dc784776b8d..e146937fda67 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ArrayTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ArrayTests.scala @@ -1,9 +1,9 @@ package io.joern.javasrc2cpg.querying.dataflow import io.joern.javasrc2cpg.testfixtures.JavaDataflowFixture -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* class ArrayTests extends JavaDataflowFixture { behavior of "Dataflow through arrays" diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/FunctionCallTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/FunctionCallTests.scala index bb3b4b15a230..53a5570b7d6b 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/FunctionCallTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/FunctionCallTests.scala @@ -1,8 +1,8 @@ package io.joern.javasrc2cpg.querying.dataflow import io.joern.javasrc2cpg.testfixtures.{JavaDataflowFixture, JavaSrcCode2CpgFixture} -import io.joern.dataflowengineoss.language._ -import io.shiftleft.semanticcpg.language._ +import io.joern.dataflowengineoss.language.* +import io.shiftleft.semanticcpg.language.* class NewFunctionCallTests extends JavaSrcCode2CpgFixture(withOssDataflow = true) { "Dataflow through function calls" should { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/IfTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/IfTests.scala index 0a85145fd59f..f5384cab3c8c 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/IfTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/IfTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying.dataflow import io.joern.javasrc2cpg.testfixtures.JavaDataflowFixture -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* class IfTests extends JavaDataflowFixture { behavior of "Dataflow through IF structures" diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/LambdaTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/LambdaTests.scala index 2ae5354391e6..c18f615eedcc 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/LambdaTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/LambdaTests.scala @@ -1,8 +1,8 @@ package io.joern.javasrc2cpg.querying.dataflow -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LambdaTests extends JavaSrcCode2CpgFixture(withOssDataflow = true) { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/LoopTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/LoopTests.scala index 556a263e0588..d664633ca5ef 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/LoopTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/LoopTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying.dataflow import io.joern.javasrc2cpg.testfixtures.JavaDataflowFixture -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* class LoopTests extends JavaDataflowFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/MemberTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/MemberTests.scala index 2d104271385c..76a9cc9d6c2f 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/MemberTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/MemberTests.scala @@ -1,8 +1,8 @@ package io.joern.javasrc2cpg.querying.dataflow import io.joern.javasrc2cpg.testfixtures.{JavaDataflowFixture, JavaSrcCode2CpgFixture} -import io.joern.dataflowengineoss.language._ -import io.shiftleft.semanticcpg.language._ +import io.joern.dataflowengineoss.language.* +import io.shiftleft.semanticcpg.language.* class MemberTests extends JavaDataflowFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/MethodReturnTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/MethodReturnTests.scala index 1f39c507cbef..e1db74388037 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/MethodReturnTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/MethodReturnTests.scala @@ -1,8 +1,8 @@ package io.joern.javasrc2cpg.querying.dataflow import io.joern.javasrc2cpg.testfixtures.JavaDataflowFixture -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* class MethodReturnTests extends JavaDataflowFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ObjectTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ObjectTests.scala index babcb1667967..cd26a4015e0b 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ObjectTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ObjectTests.scala @@ -1,8 +1,41 @@ package io.joern.javasrc2cpg.querying.dataflow -import io.joern.javasrc2cpg.testfixtures.JavaDataflowFixture -import io.joern.dataflowengineoss.language._ -import io.shiftleft.semanticcpg.language._ +import io.joern.javasrc2cpg.testfixtures.{JavaDataflowFixture, JavaSrcCode2CpgFixture} +import io.joern.dataflowengineoss.language.* +import io.shiftleft.semanticcpg.language.* + +class NewObjectTests extends JavaSrcCode2CpgFixture(withOssDataflow = true) { + + "static field passed as an argument inside a same-class static method whilst being referenced by its simple name" in { + val cpg = code(""" + |class Bar { + | static String CONST = ""; + | static void run() { + | System.out.println(CONST); + | } + |}""".stripMargin) + val sink = cpg.call("println").argument(1) + val source = cpg.literal + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List( + List(("String Bar.CONST = \"\"", Some(3)), ("System.out.println(CONST)", Some(5))) + ) + } + + "static field passed as an argument inside a same-class static method whilst being referenced by its qualified name" in { + val cpg = code(""" + |class Bar { + | static String CONST = ""; + | static void run() { + | System.out.println(Bar.CONST); + | } + |}""".stripMargin) + val sink = cpg.call("println").argument(1) + val source = cpg.literal + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List( + List(("String Bar.CONST = \"\"", Some(3)), ("System.out.println(Bar.CONST)", Some(5))) + ) + } +} class ObjectTests extends JavaDataflowFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/OperatorTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/OperatorTests.scala index 6a463d4e3819..b549f8ce6e50 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/OperatorTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/OperatorTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying.dataflow import io.joern.javasrc2cpg.testfixtures.JavaDataflowFixture -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* class OperatorTests extends JavaDataflowFixture { behavior of "Dataflow through operators" diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ReturnTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ReturnTests.scala index 6125def15d45..d2fd5859aec4 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ReturnTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ReturnTests.scala @@ -1,8 +1,8 @@ package io.joern.javasrc2cpg.querying.dataflow -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.javasrc2cpg.testfixtures.JavaDataflowFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ReturnTests extends JavaDataflowFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/SemanticTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/SemanticTests.scala index 8401b458b68d..e1dd18ac1d74 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/SemanticTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/SemanticTests.scala @@ -1,22 +1,24 @@ package io.joern.javasrc2cpg.querying.dataflow import io.joern.javasrc2cpg.testfixtures.JavaDataflowFixture -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.queryengine.{EngineContext, EngineConfig} -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.semanticsloader.FlowSemantic import io.joern.x2cpg.Defines class SemanticTests - extends JavaDataflowFixture(extraFlows = - List( - FlowSemantic.from("Test.sanitize:java.lang.String(java.lang.String)", List((0, 0), (1, 1))), - FlowSemantic.from(s"ext.Library.killParam:${Defines.UnresolvedSignature}(1)", List.empty), - FlowSemantic.from("^ext\\.Library\\.taintNone:.*", List((0, 0), (1, 1)), regex = true), - FlowSemantic.from("^ext\\.Library\\.taint1to2:.*", List((1, 2)), regex = true) + extends JavaDataflowFixture(semantics = + DefaultSemantics().plus( + List( + FlowSemantic.from("Test.sanitize:java.lang.String(java.lang.String)", List((0, 0), (1, 1))), + FlowSemantic.from(s"ext.Library.killParam:${Defines.UnresolvedSignature}(1)", List.empty), + FlowSemantic.from("^ext\\.Library\\.taintNone:.*", List((0, 0), (1, 1)), regex = true), + FlowSemantic.from("^ext\\.Library\\.taint1to2:.*", List((1, 2)), regex = true) + ) ) ) { behavior of "Dataflow through custom semantics" diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/SwitchTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/SwitchTests.scala index 89131994d336..ba82845d22f1 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/SwitchTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/SwitchTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying.dataflow import io.joern.javasrc2cpg.testfixtures.JavaDataflowFixture -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* class SwitchTests extends JavaDataflowFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/TryTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/TryTests.scala index 9f205a153558..8939c875e2a3 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/TryTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/TryTests.scala @@ -1,7 +1,7 @@ package io.joern.javasrc2cpg.querying.dataflow import io.joern.javasrc2cpg.testfixtures.JavaDataflowFixture -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* class TryTests extends JavaDataflowFixture { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/testfixtures/JavaDataflowFixture.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/testfixtures/JavaDataflowFixture.scala index ec6881f697e3..60c43e6555fd 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/testfixtures/JavaDataflowFixture.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/testfixtures/JavaDataflowFixture.scala @@ -1,22 +1,22 @@ package io.joern.javasrc2cpg.testfixtures +import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic -import io.joern.javasrc2cpg.JavaSrc2CpgTestContext +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Expression, Literal} import io.shiftleft.semanticcpg.language.* import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -class JavaDataflowFixture(extraFlows: List[FlowSemantic] = List.empty) extends AnyFlatSpec with Matchers { +class JavaDataflowFixture(semantics: Semantics = DefaultSemantics()) extends AnyFlatSpec with Matchers { implicit val resolver: ICallResolver = NoResolve implicit lazy val engineContext: EngineContext = EngineContext() val code: String = "" - lazy val cpg: Cpg = JavaSrc2CpgTestContext.buildCpgWithDataflow(code, extraFlows = extraFlows) + lazy val cpg: Cpg = JavaSrcTestCpg().withOssDataflow().withSemantics(semantics).moreCode(code) def getConstSourceSink( methodName: String, diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/testfixtures/JavaSrcCodeToCpgFixture.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/testfixtures/JavaSrcCodeToCpgFixture.scala index 717ce7bc3d95..c21f7e35634a 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/testfixtures/JavaSrcCodeToCpgFixture.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/testfixtures/JavaSrcCodeToCpgFixture.scala @@ -1,13 +1,13 @@ package io.joern.javasrc2cpg.testfixtures -import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic +import io.joern.dataflowengineoss.DefaultSemantics +import io.joern.dataflowengineoss.language.Path +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.joern.dataflowengineoss.testfixtures.{SemanticCpgTestFixture, SemanticTestCpg} import io.joern.javasrc2cpg.{Config, JavaSrc2Cpg} -import io.joern.x2cpg.X2Cpg import io.joern.x2cpg.frontendspecific.javasrc2cpg import io.joern.x2cpg.passes.frontend.XTypeRecoveryConfig -import io.joern.x2cpg.testfixtures.{Code2CpgFixture, DefaultTestCpg, LanguageFrontend, TestCpg} +import io.joern.x2cpg.testfixtures.{Code2CpgFixture, DefaultTestCpg, LanguageFrontend} import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Expression, Literal} import io.shiftleft.semanticcpg.language.* @@ -42,12 +42,12 @@ class JavaSrcTestCpg(enableTypeRecovery: Boolean = false) class JavaSrcCode2CpgFixture( withOssDataflow: Boolean = false, - extraFlows: List[FlowSemantic] = List.empty, + semantics: Semantics = DefaultSemantics(), enableTypeRecovery: Boolean = false ) extends Code2CpgFixture(() => - new JavaSrcTestCpg(enableTypeRecovery).withOssDataflow(withOssDataflow).withExtraFlows(extraFlows) + new JavaSrcTestCpg(enableTypeRecovery).withOssDataflow(withOssDataflow).withSemantics(semantics) ) - with SemanticCpgTestFixture(extraFlows) { + with SemanticCpgTestFixture(semantics) { implicit val resolver: ICallResolver = NoResolve @@ -82,4 +82,6 @@ class JavaSrcCode2CpgFixture( (source, sink) } + + protected def flowToResultPairs(path: Path): List[(String, Option[Int])] = path.resultPairs() } diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/util/ScopeTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/util/ScopeTests.scala index f2217fc6e7d3..40260e819438 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/util/ScopeTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/util/ScopeTests.scala @@ -14,8 +14,13 @@ import io.joern.javasrc2cpg.scope.Scope.CapturedVariable import io.shiftleft.codepropertygraph.generated.nodes.NewMethodParameterIn import io.joern.javasrc2cpg.scope.Scope.ScopeParameter import io.joern.javasrc2cpg.scope.Scope.NotInScope +import io.joern.x2cpg.ValidationMode class ScopeTests extends AnyFlatSpec with Matchers with BeforeAndAfterAll { + private implicit val withSchemaValidation: ValidationMode = ValidationMode.Enabled + private implicit val disableTypeFallback: Boolean = false + private val genericSignature = "GENERIC_SIGNATURE" + behavior of "javasrc2cpg scope" it should "find a simple variable for a member" in { @@ -24,7 +29,7 @@ class ScopeTests extends AnyFlatSpec with Matchers with BeforeAndAfterAll { val method = NewMethod().name("fooMethod") val scope = new Scope() - scope.pushTypeDeclScope(typeDecl, isStatic = false) + scope.pushTypeDeclScope(typeDecl, isStatic = false, Option(genericSignature)) scope.enclosingTypeDecl.get.addMember(member, isStatic = false) scope.pushMethodScope(method, ExpectedType.empty, isStatic = false) @@ -55,10 +60,13 @@ class ScopeTests extends AnyFlatSpec with Matchers with BeforeAndAfterAll { val scope = new Scope() scope.pushTypeDeclScope(outerTypeDecl, isStatic = false) scope.pushMethodScope(method, ExpectedType.empty, isStatic = false) - scope.enclosingMethod.get.addParameter(outerParameter) + scope.enclosingMethod.get.addParameter(outerParameter, genericSignature) scope.pushTypeDeclScope(innerTypeDecl, isStatic = false) - scope.lookupVariable("fooParameter") shouldBe CapturedVariable(Nil, ScopeParameter(outerParameter)) + scope.lookupVariable("fooParameter") shouldBe CapturedVariable( + Nil, + ScopeParameter(outerParameter, genericSignature) + ) } it should "find a capture chain for a captured variable in an outer-outer scope" in { @@ -72,14 +80,14 @@ class ScopeTests extends AnyFlatSpec with Matchers with BeforeAndAfterAll { val scope = new Scope() scope.pushTypeDeclScope(outerOuterTypeDecl, isStatic = false) scope.pushMethodScope(outerOuterMethod, ExpectedType.empty, isStatic = false) - scope.enclosingMethod.get.addParameter(outerOuterParameter) + scope.enclosingMethod.get.addParameter(outerOuterParameter, genericSignature) scope.pushTypeDeclScope(outerTypeDecl, isStatic = false) scope.pushMethodScope(outerMethod, ExpectedType.empty, isStatic = false) scope.pushTypeDeclScope(innerTypeDecl, isStatic = false) scope.lookupVariable("parameter") shouldBe CapturedVariable( List(outerTypeDecl), - ScopeParameter(outerOuterParameter) + ScopeParameter(outerOuterParameter, genericSignature) ) } @@ -94,7 +102,7 @@ class ScopeTests extends AnyFlatSpec with Matchers with BeforeAndAfterAll { val scope = new Scope() scope.pushTypeDeclScope(outerOuterTypeDecl, isStatic = false) scope.pushMethodScope(outerOuterMethod, ExpectedType.empty, isStatic = false) - scope.enclosingMethod.get.addParameter(outerOuterParameter) + scope.enclosingMethod.get.addParameter(outerOuterParameter, genericSignature) scope.pushTypeDeclScope(outerTypeDecl, isStatic = false) scope.pushMethodScope(outerMethod, ExpectedType.empty, isStatic = false) scope.pushTypeDeclScope(innerTypeDecl, isStatic = true) diff --git a/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/Main.scala b/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/Main.scala index cc383563ea52..20421f73f158 100644 --- a/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/Main.scala +++ b/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/Main.scala @@ -1,9 +1,12 @@ package io.joern.jimple2cpg -import io.joern.jimple2cpg.Frontend._ +import io.joern.jimple2cpg.Frontend.* import io.joern.x2cpg.{X2CpgConfig, X2CpgMain} +import io.joern.x2cpg.utils.server.FrontendHTTPServer import scopt.OParser +import java.util.concurrent.ExecutorService + /** Command line configuration parameters */ final case class Config( @@ -74,8 +77,14 @@ private object Frontend { /** Entry point for command line CPG creator */ -object Main extends X2CpgMain(cmdLineParser, new Jimple2Cpg()) { +object Main extends X2CpgMain(cmdLineParser, new Jimple2Cpg()) with FrontendHTTPServer[Config, Jimple2Cpg] { + + override protected def newDefaultConfig(): Config = Config() + + override protected val executor: ExecutorService = FrontendHTTPServer.singleThreadExecutor() + def run(config: Config, jimple2Cpg: Jimple2Cpg): Unit = { - jimple2Cpg.run(config) + if (config.serverMode) { startup() } + else { jimple2Cpg.run(config) } } } diff --git a/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/AstCreator.scala index ca9600e2cba9..6f248ffa494b 100644 --- a/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/AstCreator.scala @@ -11,7 +11,7 @@ import io.shiftleft.codepropertygraph.generated.* import io.shiftleft.codepropertygraph.generated.nodes.* import org.objectweb.asm.Type import org.slf4j.LoggerFactory -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import soot.jimple.* import soot.tagkit.* import soot.{Unit as SUnit, Local as _, *} diff --git a/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/declarations/AstForDeclarationsCreator.scala b/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/declarations/AstForDeclarationsCreator.scala index a2af7244d5f3..e9286134bce8 100644 --- a/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/declarations/AstForDeclarationsCreator.scala +++ b/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/declarations/AstForDeclarationsCreator.scala @@ -11,6 +11,7 @@ import soot.tagkit.* import scala.collection.mutable import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters.CollectionHasAsScala + trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) extends AstForTypeDeclsCreator with AstForMethodsCreator { this: AstCreator => diff --git a/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/declarations/AstForMethodsCreator.scala b/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/declarations/AstForMethodsCreator.scala index 13b8534c1618..d499ec3d4e35 100644 --- a/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/declarations/AstForMethodsCreator.scala +++ b/joern-cli/frontends/jimple2cpg/src/main/scala/io/joern/jimple2cpg/astcreation/declarations/AstForMethodsCreator.scala @@ -109,14 +109,38 @@ trait AstForMethodsCreator(implicit withSchemaValidation: ValidationMode) { this bodyStatementsInfo.targets.foreach { case (asts, unit) => asts.headOption match { case Some(value) => - diffGraph.addEdge(value.root.get, bodyStatementsInfo.unitToAsts(unit).last.root.get, EdgeTypes.CFG) + bodyStatementsInfo.unitToAsts.get(unit) match { + case Some(targetAsts) if targetAsts.nonEmpty => + diffGraph.addEdge(value.root.get, targetAsts.last.root.get, EdgeTypes.CFG) + case _ => + logger.error( + s"AstForMethodsCreator: Missing unit in unitToAsts: $unit (${unit.getClass.getSimpleName})" + ) + } case None => + logger.error("AstForMethodsCreator: Empty asts list for target") } } + bodyStatementsInfo.edges.foreach { case (a, b) => - val aNode = bodyStatementsInfo.unitToAsts(a).last.root.get - val bNode = bodyStatementsInfo.unitToAsts(b).last.root.get - diffGraph.addEdge(aNode, bNode, EdgeTypes.CFG) + (bodyStatementsInfo.unitToAsts.get(a), bodyStatementsInfo.unitToAsts.get(b)) match { + case (Some(aAsts), Some(bAsts)) if aAsts.nonEmpty && bAsts.nonEmpty => + val aNode = aAsts.last.root.get + val bNode = bAsts.last.root.get + diffGraph.addEdge(aNode, bNode, EdgeTypes.CFG) + case _ => + logger.error( + s"AstForMethodsCreator: Failed to add CFG edge between units: " + + s"a=${a.getClass.getSimpleName} (${a.toString.take(50)}) " + + s"b=${b.getClass.getSimpleName} (${b.toString.take(50)})" + ) + if (bodyStatementsInfo.unitToAsts.get(a).isEmpty) { + logger.debug(s"AstForMethodsCreator: Missing source unit in unitToAsts: $a (${a.getClass.getSimpleName})") + } + if (bodyStatementsInfo.unitToAsts.get(b).isEmpty) { + logger.debug(s"AstForMethodsCreator: Missing target unit in unitToAsts: $b (${b.getClass.getSimpleName})") + } + } } } } @@ -230,7 +254,11 @@ trait AstForMethodsCreator(implicit withSchemaValidation: ValidationMode) { this val trapStack = new mutable.Stack[soot.Trap]; body.getUnits.asScala.filterNot(isIgnoredUnit).foreach { statement => // Remove traps that ended on the previous unit - (1 to popTraps.getOrElse(statement, 0)).foreach(_ => trapStack.pop) + (1 to popTraps.getOrElse(statement, 0)).foreach { _ => + if (trapStack.nonEmpty) { + trapStack.pop() + } + } // Add traps that apply to this unit pushTraps.getOrElse(statement, List.empty).foreach(trapStack.push) @@ -246,7 +274,11 @@ trait AstForMethodsCreator(implicit withSchemaValidation: ValidationMode) { this stack.push(stack.pop().withChildren(asts)) } - stack.pop() + if (stack.nonEmpty) { + stack.pop() + } else { + Ast(blockNode(body)) + } } } diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/config/ConfigTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/config/ConfigTests.scala index 7f54c794d083..a122c5938a09 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/config/ConfigTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/config/ConfigTests.scala @@ -18,7 +18,9 @@ class ConfigTests extends AnyWordSpec with Matchers with Inside { "--output", "OUTPUT", "--exclude", - "1EXCLUDE_FILE,2EXCLUDE_FILE", + "1EXCLUDE_FILE", + "--exclude", + "2EXCLUDE_FILE", "--exclude-regex", "EXCLUDE_REGEX", // Frontend-specific args diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/io/Jimple2CpgHTTPServerTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/io/Jimple2CpgHTTPServerTests.scala new file mode 100644 index 000000000000..89d8b78aa59d --- /dev/null +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/io/Jimple2CpgHTTPServerTests.scala @@ -0,0 +1,88 @@ +package io.joern.jimple2cpg.io + +import better.files.File +import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture +import io.joern.jimple2cpg.testfixtures.JimpleCodeToCpgFixture +import io.joern.x2cpg.utils.server.FrontendHTTPClient +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader +import io.shiftleft.semanticcpg.language.* +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.collection.parallel.CollectionConverters.RangeIsParallelizable +import scala.util.Failure +import scala.util.Success + +class Jimple2CpgHTTPServerTests extends JimpleCode2CpgFixture with BeforeAndAfterAll { + + private var port: Int = -1 + + private def newProjectUnderTest(index: Option[Int] = None): File = { + val dir = File.newTemporaryDirectory("jimple2cpgTestsHttpTest") + val file = dir / "main.java" + file.createIfNotExists(createParents = true) + val indexStr = index.map(_.toString).getOrElse("") + file.writeText(s""" + |class Foo { + | static void main$indexStr(int argc, char argv) { + | System.out.println("Hello World!"); + | } + |} + |""".stripMargin) + JimpleCodeToCpgFixture.compileJava(dir.path, List(file.toJava)) + file.deleteOnExit() + dir.deleteOnExit() + } + + override def beforeAll(): Unit = { + // Start server + port = io.joern.jimple2cpg.Main.startup() + } + + override def afterAll(): Unit = { + // Stop server + io.joern.jimple2cpg.Main.stop() + } + + "Using jimple2cpg in server mode" should { + "build CPGs correctly (single test)" in { + val cpgOutFile = File.newTemporaryFile("jimple2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l should contain("""$stack2.println("Hello World!")""") + } + } + + "build CPGs correctly (multi-threaded test)" in { + (0 until 10).par.foreach { index => + val cpgOutFile = File.newTemporaryFile("jimple2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest(Some(index)) + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain(s"main$index") + cpg.call.code.l should contain("""$stack2.println("Hello World!")""") + } + } + } + } + +} diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/AnnotationTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/AnnotationTests.scala index 11bb1da6a974..dd4e5c7f83fe 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/AnnotationTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/AnnotationTests.scala @@ -3,7 +3,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Annotation, AnnotationLiteral, ArrayInitializer} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class AnnotationTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ArrayTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ArrayTests.scala index a7ccd178c297..cd97abf470bf 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ArrayTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ArrayTests.scala @@ -4,7 +4,7 @@ import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.Failed class ArrayTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/CfgTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/CfgTests.scala index 3dee0d91e31e..57325ce9c37f 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/CfgTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/CfgTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CfgTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/CodeDumperTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/CodeDumperTests.scala index 918b20fe503d..08d772ddac75 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/CodeDumperTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/CodeDumperTests.scala @@ -3,7 +3,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.Config import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CodeDumperTests extends JimpleCode2CpgFixture { private val config = Config().withDisableFileContent(false) diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ConstructorInvocationTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ConstructorInvocationTests.scala index 38ddf83d54c1..9bca4df2876c 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ConstructorInvocationTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ConstructorInvocationTests.scala @@ -3,9 +3,9 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.proto.cpg.Cpg.DispatchTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* /** These tests are based off of those found in javasrc2cpg but modified to fit to Jimple's 3-address code rule and flat * AST. diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/EnumTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/EnumTests.scala index a1204ba6a8e0..2c5deae1521f 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/EnumTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/EnumTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.Literal -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class EnumTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/FieldAccessTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/FieldAccessTests.scala index 00835ff01fff..233c5a3a65e9 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/FieldAccessTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/FieldAccessTests.scala @@ -3,7 +3,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class FieldAccessTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/FileTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/FileTests.scala index 39cf829b92c3..b0878237e449 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/FileTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/FileTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import java.io.{File => JFile} diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/IfGotoTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/IfGotoTests.scala index aa5cb73bcd2e..20cc5fe5c232 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/IfGotoTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/IfGotoTests.scala @@ -3,7 +3,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Call, Unknown} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class IfGotoTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ImplementsInterfaceTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ImplementsInterfaceTests.scala index 0238a83d4225..8a1e81d2fbd0 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ImplementsInterfaceTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ImplementsInterfaceTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import java.io.File diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/InterfaceTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/InterfaceTests.scala index 4426dea1cb6d..e457ecf7108d 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/InterfaceTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/InterfaceTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.ModifierTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.io.File diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/LocalTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/LocalTests.scala index dc9f97880ff4..8762d9b2f9ee 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/LocalTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/LocalTests.scala @@ -3,7 +3,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Local -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.Ignore class LocalTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MemberTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MemberTests.scala index 903dfeda8669..b86d2a728c6d 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MemberTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MemberTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.Ignore class MemberTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MetaDataTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MetaDataTests.scala index 650fd5eca576..1830f176ce88 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MetaDataTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MetaDataTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MetaDataTests extends JimpleCode2CpgFixture { @@ -19,8 +19,8 @@ class MetaDataTests extends JimpleCode2CpgFixture { "should not have any incoming or outgoing edges" in { cpg.metaData.size shouldBe 1 - cpg.metaData.in().l shouldBe List() - cpg.metaData.out().l shouldBe List() + cpg.metaData.in.l shouldBe List() + cpg.metaData.out.l shouldBe List() } } diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodParameterTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodParameterTests.scala index 005df7e16195..953ee19f0e25 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodParameterTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodParameterTests.scala @@ -3,7 +3,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.EvaluationStrategies -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MethodParameterTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodReturnTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodReturnTests.scala index 676c2e40ff1a..90394c197098 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodReturnTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodReturnTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MethodReturnTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodTests.scala index a6f3cbbbede7..fa5284f4b101 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/MethodTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.io.File diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/NamespaceBlockTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/NamespaceBlockTests.scala index 4af8a2ae9956..b2b389bdd073 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/NamespaceBlockTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/NamespaceBlockTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class NamespaceBlockTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ReflectionTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ReflectionTests.scala index d509ea4e5857..5b50d987c98c 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ReflectionTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/ReflectionTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* /** Right now reflection is mostly unsupported. This should be extended in later when it is. */ diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SpecialOperatorTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SpecialOperatorTests.scala index 7797f83427ed..3cb0f35d077e 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SpecialOperatorTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SpecialOperatorTests.scala @@ -4,7 +4,7 @@ import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, TypeRef} import io.shiftleft.proto.cpg.Cpg.DispatchTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class SpecialOperatorTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SwitchTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SwitchTests.scala index 41896808acaf..2214335eb27b 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SwitchTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SwitchTests.scala @@ -3,7 +3,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.JumpTarget -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class SwitchTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SynchronizedTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SynchronizedTests.scala index fad58de51c90..ce4481faaee6 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SynchronizedTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/SynchronizedTests.scala @@ -2,8 +2,8 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class SynchronizedTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/TypeDeclTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/TypeDeclTests.scala index 57d46e0f2853..78bf627e2ac4 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/TypeDeclTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/TypeDeclTests.scala @@ -3,7 +3,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.ModifierTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import java.io.File diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/TypeTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/TypeTests.scala index 3dbc9ca13a40..8108fcbef372 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/TypeTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/TypeTests.scala @@ -2,7 +2,7 @@ package io.joern.jimple2cpg.querying import io.joern.jimple2cpg.testfixtures.JimpleCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class TypeTests extends JimpleCode2CpgFixture { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/ArrayTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/ArrayTests.scala index b36a9e75e763..a375017251e0 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/ArrayTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/ArrayTests.scala @@ -1,6 +1,6 @@ package io.joern.jimple2cpg.querying.dataflow -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.jimple2cpg.testfixtures.{JimpleDataFlowCodeToCpgSuite, JimpleDataflowTestCpg} class ArrayTests extends JimpleDataFlowCodeToCpgSuite { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/FunctionCallTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/FunctionCallTests.scala index 63a49bb0cfaf..942d79905634 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/FunctionCallTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/FunctionCallTests.scala @@ -3,7 +3,7 @@ package io.joern.jimple2cpg.querying.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.jimple2cpg.testfixtures.JimpleDataFlowCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.Operators class FunctionCallTests extends JimpleDataFlowCodeToCpgSuite { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/SemanticTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/SemanticTests.scala index 2cf966788c4d..dbf0956dbdae 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/SemanticTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/SemanticTests.scala @@ -1,15 +1,18 @@ package io.joern.jimple2cpg.querying.dataflow -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.DefaultSemantics +import io.joern.dataflowengineoss.language.* import io.joern.jimple2cpg.testfixtures.{JimpleDataFlowCodeToCpgSuite, JimpleDataflowTestCpg} import io.joern.dataflowengineoss.semanticsloader.FlowSemantic import io.joern.x2cpg.Defines class SemanticTests - extends JimpleDataFlowCodeToCpgSuite(extraFlows = - List( - FlowSemantic.from("Test.sanitize:java.lang.String(java.lang.String)", List((0, 0), (1, 1))), - FlowSemantic.from("java.nio.file.Paths.get:.*\\(java.lang.String,.*\\)", List.empty, regex = true) + extends JimpleDataFlowCodeToCpgSuite(semantics = + DefaultSemantics().plus( + List( + FlowSemantic.from("Test.sanitize:java.lang.String(java.lang.String)", List((0, 0), (1, 1))), + FlowSemantic.from("java.nio.file.Paths.get:.*\\(java.lang.String,.*\\)", List.empty, regex = true) + ) ) ) { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/SwitchTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/SwitchTests.scala index 5949eda27a45..2611292db34a 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/SwitchTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/querying/dataflow/SwitchTests.scala @@ -1,6 +1,6 @@ package io.joern.jimple2cpg.querying.dataflow -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.jimple2cpg.testfixtures.JimpleDataFlowCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.Cpg diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/testfixtures/JimpleCodeToCpgFixture.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/testfixtures/JimpleCodeToCpgFixture.scala index 3d754650d482..ebdca165c635 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/testfixtures/JimpleCodeToCpgFixture.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/testfixtures/JimpleCodeToCpgFixture.scala @@ -1,6 +1,7 @@ package io.joern.jimple2cpg.testfixtures -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic +import io.joern.dataflowengineoss.DefaultSemantics +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.joern.dataflowengineoss.testfixtures.{SemanticCpgTestFixture, SemanticTestCpg} import io.joern.jimple2cpg.{Config, Jimple2Cpg} import io.joern.x2cpg.X2Cpg @@ -23,9 +24,9 @@ trait Jimple2CpgFrontend extends LanguageFrontend { } } -class JimpleCode2CpgFixture(withOssDataflow: Boolean = false, extraFlows: List[FlowSemantic] = List.empty) - extends Code2CpgFixture(() => new JimpleTestCpg().withOssDataflow(withOssDataflow).withExtraFlows(extraFlows)) - with SemanticCpgTestFixture(extraFlows) {} +class JimpleCode2CpgFixture(withOssDataflow: Boolean = false, semantics: Semantics = DefaultSemantics()) + extends Code2CpgFixture(() => new JimpleTestCpg().withOssDataflow(withOssDataflow).withSemantics(semantics)) + with SemanticCpgTestFixture(semantics) {} class JimpleTestCpg extends DefaultTestCpg with Jimple2CpgFrontend with SemanticTestCpg { diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/testfixtures/JimpleDataflowCodeToCpgSuite.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/testfixtures/JimpleDataflowCodeToCpgSuite.scala index 8fa238410617..c2baab59cc80 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/testfixtures/JimpleDataflowCodeToCpgSuite.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/testfixtures/JimpleDataflowCodeToCpgSuite.scala @@ -1,15 +1,16 @@ package io.joern.jimple2cpg.testfixtures +import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlowOptions} import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.joern.x2cpg.testfixtures.Code2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.LayerCreatorContext -class JimpleDataflowTestCpg(val extraFlows: List[FlowSemantic] = List.empty) extends JimpleTestCpg { +class JimpleDataflowTestCpg(val semantics: Semantics = DefaultSemantics()) extends JimpleTestCpg { implicit val resolver: ICallResolver = NoResolve implicit lazy val engineContext: EngineContext = EngineContext() @@ -17,14 +18,14 @@ class JimpleDataflowTestCpg(val extraFlows: List[FlowSemantic] = List.empty) ext override def applyPasses(): Unit = { super.applyPasses() val context = new LayerCreatorContext(this) - val options = new OssDataFlowOptions(extraFlows = extraFlows) + val options = new OssDataFlowOptions(semantics = semantics) new OssDataFlow(options).run(context) } } -class JimpleDataFlowCodeToCpgSuite(val extraFlows: List[FlowSemantic] = List.empty) - extends Code2CpgFixture(() => new JimpleDataflowTestCpg(extraFlows)) { +class JimpleDataFlowCodeToCpgSuite(val semantics: Semantics = DefaultSemantics()) + extends Code2CpgFixture(() => new JimpleDataflowTestCpg(semantics)) { implicit var context: EngineContext = EngineContext() diff --git a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/unpacking/JarUnpackingTests.scala b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/unpacking/JarUnpackingTests.scala index b36deb4b9189..5f641af9a7b7 100644 --- a/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/unpacking/JarUnpackingTests.scala +++ b/joern-cli/frontends/jimple2cpg/src/test/scala/io/joern/jimple2cpg/unpacking/JarUnpackingTests.scala @@ -12,14 +12,15 @@ import org.scalatest.matchers.should.Matchers.* import org.scalatest.wordspec.AnyWordSpec import java.nio.file.{Files, Path, Paths} +import scala.compiletime.uninitialized import scala.util.{Failure, Success, Try} class JarUnpackingTests extends AnyWordSpec with Matchers with BeforeAndAfterAll { - var recurseCpgs: Map[String, Cpg] = scala.compiletime.uninitialized - var noRecurseCpgs: Map[String, Cpg] = scala.compiletime.uninitialized - var depthsCpgs: Map[String, Cpg] = scala.compiletime.uninitialized - var slippyCpg: Cpg = scala.compiletime.uninitialized + var recurseCpgs: Map[String, Cpg] = uninitialized + var noRecurseCpgs: Map[String, Cpg] = uninitialized + var depthsCpgs: Map[String, Cpg] = uninitialized + var slippyCpg: Cpg = uninitialized override protected def beforeAll(): Unit = { super.beforeAll() diff --git a/joern-cli/frontends/jssrc2cpg/build.sbt b/joern-cli/frontends/jssrc2cpg/build.sbt index f44a65b66d99..2bbc5568250d 100644 --- a/joern-cli/frontends/jssrc2cpg/build.sbt +++ b/joern-cli/frontends/jssrc2cpg/build.sbt @@ -69,16 +69,15 @@ astGenDlTask := { val astGenDir = baseDirectory.value / "bin" / "astgen" astGenBinaryNames.value.foreach { fileName => - DownloadHelper.ensureIsAvailable(s"${astGenDlUrl.value}$fileName", astGenDir / fileName) + val file = astGenDir / fileName + DownloadHelper.ensureIsAvailable(s"${astGenDlUrl.value}$fileName", file) + // permissions are lost during the download; need to set them manually + file.setExecutable(true, false) } val distDir = (Universal / stagingDirectory).value / "bin" / "astgen" distDir.mkdirs() - IO.copyDirectory(astGenDir, distDir) - - // permissions are lost during the download; need to set them manually - astGenDir.listFiles().foreach(_.setExecutable(true, false)) - distDir.listFiles().foreach(_.setExecutable(true, false)) + IO.copyDirectory(astGenDir, distDir, preserveExecutable = true) } Compile / compile := ((Compile / compile) dependsOn astGenDlTask).value @@ -93,3 +92,7 @@ stage := Def Universal / packageName := name.value Universal / topLevelDirectory := None + +/** write the astgen version to the manifest for downstream usage */ +Compile / packageBin / packageOptions += + Package.ManifestAttributes(new java.util.jar.Attributes.Name("JS-AstGen-Version") -> astGenVersion.value) diff --git a/joern-cli/frontends/jssrc2cpg/src/main/resources/application.conf b/joern-cli/frontends/jssrc2cpg/src/main/resources/application.conf index caa918ef495a..917d0508240b 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/resources/application.conf +++ b/joern-cli/frontends/jssrc2cpg/src/main/resources/application.conf @@ -1,3 +1,3 @@ jssrc2cpg { - astgen_version: "3.14.0" + astgen_version: "3.21.0" } diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/JsSrc2Cpg.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/JsSrc2Cpg.scala index 22b3ce8120a4..258db2764f8c 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/JsSrc2Cpg.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/JsSrc2Cpg.scala @@ -18,11 +18,10 @@ import scala.util.Try class JsSrc2Cpg extends X2CpgFrontend[Config] { - private val report: Report = new Report() - def createCpg(config: Config): Try[Cpg] = { withNewEmptyCpg(config.outputPath, config) { (cpg, config) => File.usingTemporaryDirectory("jssrc2cpgOut") { tmpDir => + val report = new Report() val astGenResult = new AstGenRunner(config).execute(tmpDir) val hash = HashUtil.sha256(astGenResult.parsedFiles.map { case (_, file) => File(file).path }) diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/Main.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/Main.scala index 06b270dbc514..4582b898636e 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/Main.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/Main.scala @@ -4,6 +4,7 @@ import io.joern.jssrc2cpg.Frontend.* import io.joern.x2cpg.passes.frontend.{TypeRecoveryParserConfig, XTypeRecovery, XTypeRecoveryConfig} import io.joern.x2cpg.utils.Environment import io.joern.x2cpg.{X2CpgConfig, X2CpgMain} +import io.joern.x2cpg.utils.server.FrontendHTTPServer import scopt.OParser import java.nio.file.Paths @@ -34,14 +35,19 @@ object Frontend { } -object Main extends X2CpgMain(cmdLineParser, new JsSrc2Cpg()) { +object Main extends X2CpgMain(cmdLineParser, new JsSrc2Cpg()) with FrontendHTTPServer[Config, JsSrc2Cpg] { + + override protected def newDefaultConfig(): Config = Config() def run(config: Config, jssrc2cpg: JsSrc2Cpg): Unit = { - val absPath = Paths.get(config.inputPath).toAbsolutePath.toString - if (Environment.pathExists(absPath)) { - jssrc2cpg.run(config.withInputPath(absPath)) - } else { - System.exit(1) + if (config.serverMode) { startup() } + else { + val absPath = Paths.get(config.inputPath).toAbsolutePath.toString + if (Environment.pathExists(absPath)) { + jssrc2cpg.run(config.withInputPath(absPath)) + } else { + System.exit(1) + } } } diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstCreator.scala index dd6b95bb45c5..a2c81ee3c6ea 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstCreator.scala @@ -1,24 +1,31 @@ package io.joern.jssrc2cpg.astcreation import io.joern.jssrc2cpg.Config -import io.joern.jssrc2cpg.datastructures.{MethodScope, Scope} +import io.joern.jssrc2cpg.datastructures.MethodScope +import io.joern.jssrc2cpg.datastructures.Scope import io.joern.jssrc2cpg.parser.BabelAst.* import io.joern.jssrc2cpg.parser.BabelJsonParser.ParseResult import io.joern.jssrc2cpg.parser.BabelNodeInfo +import io.joern.x2cpg.Ast +import io.joern.x2cpg.AstCreatorBase +import io.joern.x2cpg.ValidationMode +import io.joern.x2cpg.AstNodeBuilder as X2CpgAstNodeBuilder +import io.joern.x2cpg.datastructures.Global import io.joern.x2cpg.datastructures.Stack.* import io.joern.x2cpg.frontendspecific.jssrc2cpg.Defines -import io.joern.x2cpg.utils.NodeBuilders.{newMethodReturnNode, newModifierNode} -import io.joern.x2cpg.{Ast, AstCreatorBase, ValidationMode, AstNodeBuilder as X2CpgAstNodeBuilder} -import io.joern.x2cpg.datastructures.Global -import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, ModifierTypes, NodeTypes} +import io.joern.x2cpg.utils.NodeBuilders.newMethodReturnNode +import io.joern.x2cpg.utils.NodeBuilders.newModifierNode +import io.shiftleft.codepropertygraph.generated.EvaluationStrategies +import io.shiftleft.codepropertygraph.generated.ModifierTypes +import io.shiftleft.codepropertygraph.generated.NodeTypes import io.shiftleft.codepropertygraph.generated.nodes.NewBlock import io.shiftleft.codepropertygraph.generated.nodes.NewFile -import io.shiftleft.codepropertygraph.generated.nodes.NewMethod import io.shiftleft.codepropertygraph.generated.nodes.NewNode import io.shiftleft.codepropertygraph.generated.nodes.NewTypeDecl import io.shiftleft.codepropertygraph.generated.nodes.NewTypeRef -import org.slf4j.{Logger, LoggerFactory} -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder +import org.slf4j.Logger +import org.slf4j.LoggerFactory import ujson.Value import scala.collection.mutable @@ -73,28 +80,13 @@ class AstCreator(val config: Config, val global: Global, val parserResult: Parse } private def createProgramMethod(): Ast = { - val path = parserResult.filename - val astNodeInfo = createBabelNodeInfo(parserResult.json("ast")) - val lineNumber = astNodeInfo.lineNumber - val columnNumber = astNodeInfo.columnNumber - val lineNumberEnd = astNodeInfo.lineNumberEnd - val columnNumberEnd = astNodeInfo.columnNumberEnd - val name = Defines.Program - val fullName = s"$path:$name" + val path = parserResult.filename + val astNodeInfo = createBabelNodeInfo(parserResult.json("ast")) + val name = Defines.Program + val fullName = s"$path:$name" val programMethod = - NewMethod() - .order(1) - .name(name) - .code(name) - .fullName(fullName) - .filename(path) - .lineNumber(lineNumber) - .lineNumberEnd(lineNumberEnd) - .columnNumber(columnNumber) - .columnNumberEnd(columnNumberEnd) - .astParentType(NodeTypes.TYPE_DECL) - .astParentFullName(fullName) + methodNode(astNodeInfo, name, name, fullName, None, path, Some(NodeTypes.TYPE_DECL), Some(fullName)).order(1) val functionTypeAndTypeDeclAst = createFunctionTypeAndTypeDeclAst(astNodeInfo, programMethod, methodAstParentStack.head, name, fullName, path) diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstCreatorHelper.scala index e4c16a4c1892..d5108eaa24c1 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstCreatorHelper.scala @@ -4,12 +4,12 @@ import io.joern.jssrc2cpg.datastructures.* import io.joern.jssrc2cpg.parser.BabelAst.* import io.joern.jssrc2cpg.parser.BabelNodeInfo import io.joern.x2cpg.frontendspecific.jssrc2cpg.Defines +import io.joern.x2cpg.utils.IntervalKeyPool import io.joern.x2cpg.utils.NodeBuilders.{newClosureBindingNode, newLocalNode} import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{EdgeTypes, EvaluationStrategies} import io.shiftleft.codepropertygraph.generated.nodes.File.PropertyDefaults -import io.shiftleft.passes.IntervalKeyPool import ujson.Value import scala.collection.{mutable, SortedMap} diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala index e30ee495dc96..da800b7f715a 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala @@ -85,8 +85,10 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { ): Ast = { from match { case Some(value) => + val identNode = identifierNode(declaration, value) + scope.addVariableReference(name, identNode) val call = createFieldAccessCallAst( - identifierNode(declaration, value, Seq.empty), + identNode, createFieldIdentifierNode(name, declaration.lineNumber, declaration.columnNumber), declaration.lineNumber, declaration.columnNumber @@ -99,9 +101,11 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { declaration.columnNumber ) case None => + val identNode = identifierNode(declaration, name) + scope.addVariableReference(name, identNode) createAssignmentCallAst( exportCallAst, - Ast(identifierNode(declaration, name)), + Ast(identNode), s"${codeOf(exportCallAst.nodes.head)} = $name", declaration.lineNumber, declaration.columnNumber @@ -262,7 +266,6 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { createExportAssignmentCallAst(name, exportCallAst, assignment, None) } } - setArgumentIndices(declAsts) blockAst(createBlockNode(assignment), declAsts) } diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForExpressionsCreator.scala index 01fe0605917d..222655e7fe74 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -3,11 +3,16 @@ package io.joern.jssrc2cpg.astcreation import io.joern.jssrc2cpg.parser.BabelAst.* import io.joern.jssrc2cpg.parser.BabelNodeInfo import io.joern.jssrc2cpg.passes.EcmaBuiltins -import io.joern.x2cpg.frontendspecific.jssrc2cpg.{Defines, GlobalBuiltins} -import io.joern.x2cpg.{Ast, ValidationMode} +import io.joern.x2cpg.Ast +import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.datastructures.Stack.* +import io.joern.x2cpg.frontendspecific.jssrc2cpg.Defines +import io.joern.x2cpg.frontendspecific.jssrc2cpg.GlobalBuiltins +import io.shiftleft.codepropertygraph.generated.DispatchTypes +import io.shiftleft.codepropertygraph.generated.EdgeTypes +import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.nodes.NewIdentifier import io.shiftleft.codepropertygraph.generated.nodes.NewNode -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, Operators} import scala.util.Try @@ -21,20 +26,16 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case MemberExpression => code(callee.json("property")) case _ => callee.code } - val callNode = - createStaticCallNode(callExpr.code, callName, fullName, callee.lineNumber, callee.columnNumber) - val argAsts = astForNodes(callExpr.json("arguments").arr.toList) + val callNode = createStaticCallNode(callExpr.code, callName, fullName, callee.lineNumber, callee.columnNumber) + val argAsts = astForNodes(callExpr.json("arguments").arr.toList) callAst(callNode, argAsts) } - private def handleCallNodeArgs( - callExpr: BabelNodeInfo, - receiverAst: Ast, - baseNode: NewNode, - callName: String - ): Ast = { + private case class CallExpressionInfo(receiverAst: Ast, baseNode: NewIdentifier, callName: String) + + private def handleCallNodeArgs(callExpr: BabelNodeInfo, callExpressionInfo: CallExpressionInfo): Ast = { val args = astForNodes(callExpr.json("arguments").arr.toList) - val callNode_ = callNode(callExpr, callExpr.code, callName, DispatchTypes.DYNAMIC_DISPATCH) + val callNode_ = callNode(callExpr, callExpr.code, callExpressionInfo.callName, DispatchTypes.DYNAMIC_DISPATCH) // If the callee is a function itself, e.g. closure, then resolve this locally, if possible callExpr.json.obj .get("callee") @@ -44,7 +45,51 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case _ => None } .foreach { case (name, fullName) => callNode_.name(name).methodFullName(fullName) } - callAst(callNode_, args, receiver = Option(receiverAst), base = Option(Ast(baseNode))) + callAst( + callNode_, + args, + receiver = Option(callExpressionInfo.receiverAst), + base = Option(Ast(callExpressionInfo.baseNode)) + ) + } + + private def callExpressionInfoForCallLikeExpr(callLike: BabelNodeInfo): CallExpressionInfo = { + callLike.node match { + case MemberExpression => + val base = createBabelNodeInfo(callLike.json("object")) + val member = createBabelNodeInfo(callLike.json("property")) + base.node match { + case ThisExpression => + val receiverAst = astForNodeWithFunctionReference(callLike.json) + val baseNode = identifierNode(base, base.code).dynamicTypeHintFullName(typeHintForThisExpression()) + scope.addVariableReference(base.code, baseNode) + CallExpressionInfo(receiverAst, baseNode, member.code) + case Identifier => + val receiverAst = astForNodeWithFunctionReference(callLike.json) + val baseNode = identifierNode(base, base.code) + scope.addVariableReference(base.code, baseNode) + CallExpressionInfo(receiverAst, baseNode, member.code) + case _ => + val tmpVarName = generateUnusedVariableName(usedVariableNames, "_tmp") + val baseTmpNode = identifierNode(base, tmpVarName) + scope.addVariableReference(tmpVarName, baseTmpNode) + val baseAst = astForNodeWithFunctionReference(base.json) + val code = s"(${codeOf(baseTmpNode)} = ${base.code})" + val tmpAssignmentAst = + createAssignmentCallAst(Ast(baseTmpNode), baseAst, code, base.lineNumber, base.columnNumber) + val memberNode = createFieldIdentifierNode(member.code, member.lineNumber, member.columnNumber) + val fieldAccessAst = + createFieldAccessCallAst(tmpAssignmentAst, memberNode, callLike.lineNumber, callLike.columnNumber) + val thisTmpNode = identifierNode(callLike, tmpVarName) + scope.addVariableReference(tmpVarName, thisTmpNode) + CallExpressionInfo(fieldAccessAst, thisTmpNode, member.code) + } + case _ => + val receiverAst = astForNodeWithFunctionReference(callLike.json) + val thisNode = identifierNode(callLike, "this").dynamicTypeHintFullName(typeHintForThisExpression()) + scope.addVariableReference(thisNode.name, thisNode) + CallExpressionInfo(receiverAst, thisNode, callLike.code) + } } protected def astForCallExpression(callExpr: BabelNodeInfo): Ast = { @@ -53,44 +98,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { if (GlobalBuiltins.builtins.contains(calleeCode)) { createBuiltinStaticCall(callExpr, callee, calleeCode) } else { - val (receiverAst, baseNode, callName) = callee.node match { - case MemberExpression => - val base = createBabelNodeInfo(callee.json("object")) - val member = createBabelNodeInfo(callee.json("property")) - base.node match { - case ThisExpression => - val receiverAst = astForNodeWithFunctionReference(callee.json) - val baseNode = identifierNode(base, base.code).dynamicTypeHintFullName(typeHintForThisExpression()) - scope.addVariableReference(base.code, baseNode) - (receiverAst, baseNode, member.code) - case Identifier => - val receiverAst = astForNodeWithFunctionReference(callee.json) - val baseNode = identifierNode(base, base.code) - scope.addVariableReference(base.code, baseNode) - (receiverAst, baseNode, member.code) - case _ => - val tmpVarName = generateUnusedVariableName(usedVariableNames, "_tmp") - val baseTmpNode = identifierNode(base, tmpVarName) - scope.addVariableReference(tmpVarName, baseTmpNode) - val baseAst = astForNodeWithFunctionReference(base.json) - val code = s"(${codeOf(baseTmpNode)} = ${base.code})" - val tmpAssignmentAst = - createAssignmentCallAst(Ast(baseTmpNode), baseAst, code, base.lineNumber, base.columnNumber) - val memberNode = createFieldIdentifierNode(member.code, member.lineNumber, member.columnNumber) - val fieldAccessAst = - createFieldAccessCallAst(tmpAssignmentAst, memberNode, callee.lineNumber, callee.columnNumber) - val thisTmpNode = identifierNode(callee, tmpVarName) - scope.addVariableReference(tmpVarName, thisTmpNode) - - (fieldAccessAst, thisTmpNode, member.code) - } - case _ => - val receiverAst = astForNodeWithFunctionReference(callee.json) - val thisNode = identifierNode(callee, "this").dynamicTypeHintFullName(typeHintForThisExpression()) - scope.addVariableReference(thisNode.name, thisNode) - (receiverAst, thisNode, calleeCode) - } - handleCallNodeArgs(callExpr, receiverAst, baseNode, callName) + val callExpressionInfo = callExpressionInfoForCallLikeExpr(callee) + handleCallNodeArgs(callExpr, callExpressionInfo) } } @@ -114,9 +123,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { diffGraph.addEdge(localAstParentStack.head, localTmpAllocNode, EdgeTypes.AST) scope.addVariableReference(tmpAllocName, tmpAllocNode1) - val allocCallNode = - callNode(newExpr, ".alloc", Operators.alloc, DispatchTypes.STATIC_DISPATCH) - + val allocCallNode = callNode(newExpr, ".alloc", Operators.alloc, DispatchTypes.STATIC_DISPATCH) val assignmentTmpAllocCallNode = createAssignmentCallAst( tmpAllocNode1, @@ -127,11 +134,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { ) val tmpAllocNode2 = identifierNode(newExpr, tmpAllocName) - - val receiverNode = astForNodeWithFunctionReference(callee) - - val callAst = handleCallNodeArgs(newExpr, receiverNode, tmpAllocNode2, Defines.OperatorsNew) - + val receiverNode = astForNodeWithFunctionReference(callee) + val callAst = handleCallNodeArgs(newExpr, CallExpressionInfo(receiverNode, tmpAllocNode2, Defines.OperatorsNew)) val tmpAllocReturnNode = Ast(identifierNode(newExpr, tmpAllocName)) scope.popScope() @@ -217,31 +221,21 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { astForBinaryExpression(logicalExpr) protected def astForTSNonNullExpression(nonNullExpr: BabelNodeInfo): Ast = { - val op = Operators.notNullAssert - val callNode_ = - callNode(nonNullExpr, nonNullExpr.code, op, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(astForNodeWithFunctionReference(nonNullExpr.json("expression"))) + val op = Operators.notNullAssert + val callNode_ = callNode(nonNullExpr, nonNullExpr.code, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(astForNodeWithFunctionReference(nonNullExpr.json("expression"))) callAst(callNode_, argAsts) } protected def astForCastExpression(castExpr: BabelNodeInfo): Ast = { - val op = Operators.cast - val lhsNode = castExpr.json("typeAnnotation") - val rhsAst = astForNodeWithFunctionReference(castExpr.json("expression")) - typeFor(castExpr) match { - case tpe if GlobalBuiltins.builtins.contains(tpe) || Defines.isBuiltinType(tpe) => - val lhsAst = Ast(literalNode(castExpr, code(lhsNode), Option(tpe))) - val node = - callNode(castExpr, castExpr.code, op, DispatchTypes.STATIC_DISPATCH).dynamicTypeHintFullName(Seq(tpe)) - val argAsts = List(lhsAst, rhsAst) - callAst(node, argAsts) - case t => - val possibleTypes = Seq(t) - val lhsAst = Ast(literalNode(castExpr, code(lhsNode), None).possibleTypes(possibleTypes)) - val node = callNode(castExpr, castExpr.code, op, DispatchTypes.STATIC_DISPATCH).possibleTypes(possibleTypes) - val argAsts = List(lhsAst, rhsAst) - callAst(node, argAsts) - } + val op = Operators.cast + val lhsNode = castExpr.json("typeAnnotation") + val rhsAst = astForNodeWithFunctionReference(castExpr.json("expression")) + val possibleTypes = Seq(typeFor(castExpr)) + val lhsAst = Ast(literalNode(castExpr, code(lhsNode), None).possibleTypes(possibleTypes)) + val node = callNode(castExpr, castExpr.code, op, DispatchTypes.STATIC_DISPATCH).possibleTypes(possibleTypes) + val argAsts = List(lhsAst, rhsAst) + callAst(node, argAsts) } protected def astForBinaryExpression(binExpr: BabelNodeInfo): Ast = { @@ -340,15 +334,14 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } protected def astForAwaitExpression(awaitExpr: BabelNodeInfo): Ast = { - val node = - callNode(awaitExpr, awaitExpr.code, ".await", DispatchTypes.STATIC_DISPATCH) + val node = callNode(awaitExpr, awaitExpr.code, ".await", DispatchTypes.STATIC_DISPATCH) val argAsts = List(astForNodeWithFunctionReference(awaitExpr.json("argument"))) callAst(node, argAsts) } - protected def astForArrayExpression(arrExpr: BabelNodeInfo): Ast = { + protected def astForArrayExpression(arrExpr: BabelNodeInfo, elementsKey: String = "elements"): Ast = { val MAX_INITIALIZERS = 1000 - val elementsJsons = Try(arrExpr.json("elements").arr).toOption.toList.flatten + val elementsJsons = Try(arrExpr.json(elementsKey).arr).toOption.toList.flatten val elements = elementsJsons.slice(0, MAX_INITIALIZERS) if (elements.isEmpty) { Ast( @@ -379,7 +372,6 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val elementNodeInfo = createBabelNodeInfo(element) val elementLineNumber = elementNodeInfo.lineNumber val elementColumnNumber = elementNodeInfo.columnNumber - val elementCode = elementNodeInfo.code val elementNode = elementNodeInfo.node match { case RestElement => val arg1Ast = Ast(identifierNode(arrExpr, tmpName)) @@ -388,6 +380,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { astForNodeWithFunctionReference(element) } + val elementCode = elementNode.root.map(codeOf).getOrElse(elementNodeInfo.code) val pushCallNode = callNode(elementNodeInfo, s"$tmpName.push($elementCode)", "", DispatchTypes.DYNAMIC_DISPATCH) @@ -416,14 +409,34 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } } + private def handleTemplateExpressionArgs(templateExpr: BabelNodeInfo, callExpressionInfo: CallExpressionInfo): Ast = { + val expressionArgs = templateExpr.json("quasi")("expressions").arr.toList.map(astForNodeWithFunctionReference) + val quasisArg = astForArrayExpression(createBabelNodeInfo(templateExpr.json("quasi")), "quasis") + val callNode_ = + callNode(templateExpr, templateExpr.code, callExpressionInfo.callName, DispatchTypes.DYNAMIC_DISPATCH) + // If the callee is a function itself, e.g. closure, then resolve this locally, if possible + templateExpr.json.obj + .get("callee") + .map(createBabelNodeInfo) + .flatMap { + case callee if callee.node.isInstanceOf[FunctionLike] => functionNodeToNameAndFullName.get(callee) + case _ => None + } + .foreach { case (name, fullName) => callNode_.name(name).methodFullName(fullName) } + callAst( + callNode_, + quasisArg +: expressionArgs, + receiver = Option(callExpressionInfo.receiverAst), + base = Option(Ast(callExpressionInfo.baseNode)) + ) + } + + /** Lowering from expressions like, x`a ${1+1} b` to x(["a ", " b"], 1+1) + */ def astForTemplateExpression(templateExpr: BabelNodeInfo): Ast = { - val argumentAst = astForNodeWithFunctionReference(templateExpr.json("quasi")) - val callName = code(templateExpr.json("tag")) - val callCode = s"$callName(${codeOf(argumentAst.nodes.head)})" - val templateExprCall = - callNode(templateExpr, callCode, callName, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(argumentAst) - callAst(templateExprCall, argAsts) + val callee = createBabelNodeInfo(templateExpr.json("tag")) + val callExpressionInfo = callExpressionInfoForCallLikeExpr(callee) + handleTemplateExpressionArgs(templateExpr, callExpressionInfo) } protected def astForObjectExpression(objExpr: BabelNodeInfo): Ast = { diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForFunctionsCreator.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForFunctionsCreator.scala index bcaea8374b4c..c357f76de027 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForFunctionsCreator.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForFunctionsCreator.scala @@ -5,7 +5,7 @@ import io.joern.jssrc2cpg.parser.BabelAst.* import io.joern.jssrc2cpg.parser.BabelNodeInfo import io.joern.x2cpg.datastructures.Stack.* import io.joern.x2cpg.frontendspecific.jssrc2cpg.Defines -import io.joern.x2cpg.utils.NodeBuilders.{newBindingNode, newModifierNode} +import io.joern.x2cpg.utils.NodeBuilders.newModifierNode import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.{Identifier as _, *} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, EvaluationStrategies, ModifierTypes} @@ -325,10 +325,13 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th } protected def astForTSDeclareFunction(func: BabelNodeInfo): Ast = { - val functionNode = createMethodDefinitionNode(func) - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(getParentTypeDecl, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, functionNode, EdgeTypes.REF) + val functionNode = createMethodDefinitionNode(func) + val tpe = typeFor(func) + val possibleTypes = Seq(tpe) + val typeFullName = if (Defines.isBuiltinType(tpe)) tpe else Defines.Any + val memberNode_ = memberNode(func, functionNode.name, func.code, typeFullName, Seq(functionNode.fullName)) + .possibleTypes(possibleTypes) + diffGraph.addEdge(getParentTypeDecl, memberNode_, EdgeTypes.AST) addModifier(functionNode, func.json) Ast(functionNode) } diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForTypesCreator.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForTypesCreator.scala index ccb700cf3f0d..c20227c61512 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForTypesCreator.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstForTypesCreator.scala @@ -6,9 +6,8 @@ import io.joern.jssrc2cpg.parser.BabelNodeInfo import io.joern.x2cpg.{Ast, ValidationMode} import io.joern.x2cpg.datastructures.Stack.* import io.joern.x2cpg.frontendspecific.jssrc2cpg.Defines -import io.joern.x2cpg.utils.NodeBuilders.newBindingNode import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, ModifierTypes, Operators} +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, ModifierTypes, Operators, PropertyNames} import ujson.Value import scala.util.Try @@ -30,7 +29,7 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: } else nameTpe val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + val astParentFullName = methodAstParentStack.head.properties(PropertyNames.FULL_NAME).toString val aliasTypeDeclNode = typeDeclNode(alias, aliasName, aliasFullName, parserResult.filename, alias.code, astParentType, astParentFullName) @@ -94,17 +93,19 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: forElem: BabelNodeInfo, methodBlockContent: List[Ast] = List.empty ): MethodAst = { + val fakeStartEnd = + s""" + | "start": ${start(forElem.json).getOrElse(-1)}, + | "end": ${end(forElem.json).getOrElse(-1)} + |""".stripMargin + val fakeConstructorCode = s"""{ | "type": "ClassMethod", + | $fakeStartEnd, | "key": { | "type": "Identifier", | "name": "constructor", - | "loc": { - | "start": { - | "line": ${forElem.lineNumber.getOrElse(-1)}, - | "column": ${forElem.columnNumber.getOrElse(-1)} - | } - | } + | $fakeStartEnd | }, | "kind": "constructor", | "id": null, @@ -180,18 +181,12 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: val typeFullName = if (Defines.isBuiltinType(tpe)) tpe else Defines.Any val memberNode_ = nodeInfo.node match { case TSDeclareMethod | TSDeclareFunction => - val function = createMethodDefinitionNode(nodeInfo) - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, function, EdgeTypes.REF) + val function = createMethodDefinitionNode(nodeInfo) addModifier(function, nodeInfo.json) memberNode(nodeInfo, function.name, nodeInfo.code, typeFullName, Seq(function.fullName)) .possibleTypes(possibleTypes) case ClassMethod | ClassPrivateMethod => - val function = createMethodAstAndNode(nodeInfo).methodNode - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, function, EdgeTypes.REF) + val function = createMethodAstAndNode(nodeInfo).methodNode addModifier(function, nodeInfo.json) memberNode(nodeInfo, function.name, nodeInfo.code, typeFullName, Seq(function.fullName)) .possibleTypes(possibleTypes) @@ -237,7 +232,7 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: registerType(typeFullName) val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + val astParentFullName = methodAstParentStack.head.properties(PropertyNames.FULL_NAME).toString val typeDeclNode_ = typeDeclNode( tsEnum, @@ -319,7 +314,7 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: registerType(typeFullName) val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + val astParentFullName = methodAstParentStack.head.properties(PropertyNames.FULL_NAME).toString val superClass = Try(createBabelNodeInfo(clazz.json("superClass")).code).toOption.toSeq val implements = Try(clazz.json("implements").arr.map(createBabelNodeInfo(_).code)).toOption.toSeq.flatten @@ -474,7 +469,7 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: registerType(typeFullName) val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + val astParentFullName = methodAstParentStack.head.properties(PropertyNames.FULL_NAME).toString val extendz = Try(tsInterface.json("extends").arr.map(createBabelNodeInfo(_).code)).toOption.toSeq.flatten @@ -500,9 +495,9 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: val constructorNode = interfaceConstructor(typeName, tsInterface) diffGraph.addEdge(constructorNode, NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), EdgeTypes.AST) - val constructorBindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode_, constructorBindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(constructorBindingNode, constructorNode, EdgeTypes.REF) + val memberNode_ = + memberNode(tsInterface, constructorNode.name, constructorNode.code, typeFullName, Seq(constructorNode.fullName)) + diffGraph.addEdge(typeDeclNode_, memberNode_, EdgeTypes.AST) val interfaceBodyElements = classMembers(tsInterface, withConstructor = false) @@ -514,9 +509,6 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: val memberNodes = nodeInfo.node match { case TSCallSignatureDeclaration | TSMethodSignature => val functionNode = createMethodDefinitionNode(nodeInfo) - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode_, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, functionNode, EdgeTypes.REF) addModifier(functionNode, nodeInfo.json) Seq( memberNode(nodeInfo, functionNode.name, nodeInfo.code, typeFullName, Seq(functionNode.fullName)) diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstNodeBuilder.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstNodeBuilder.scala index 2ef960276b6a..2c6ccdb5dc6c 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstNodeBuilder.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/AstNodeBuilder.scala @@ -6,8 +6,7 @@ import io.joern.x2cpg.{Ast, ValidationMode} import io.joern.x2cpg.frontendspecific.jssrc2cpg.Defines import io.joern.x2cpg.utils.NodeBuilders.newMethodReturnNode import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, PropertyNames} trait AstNodeBuilder(implicit withSchemaValidation: ValidationMode) { this: AstCreator => protected def createMethodReturnNode(func: BabelNodeInfo): NewMethodReturn = { @@ -251,7 +250,7 @@ trait AstNodeBuilder(implicit withSchemaValidation: ValidationMode) { this: AstC registerType(methodFullName) val astParentType = parentNode.label - val astParentFullName = parentNode.properties("FULL_NAME").toString + val astParentFullName = parentNode.properties(PropertyNames.FULL_NAME).toString val functionTypeDeclNode = typeDeclNode( node, diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/TypeHelper.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/TypeHelper.scala index 744fde0d8a05..726106fa22ad 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/TypeHelper.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/astcreation/TypeHelper.scala @@ -1,6 +1,6 @@ package io.joern.jssrc2cpg.astcreation -import io.joern.jssrc2cpg.parser.BabelAst._ +import io.joern.jssrc2cpg.parser.BabelAst.* import io.joern.jssrc2cpg.parser.BabelNodeInfo import io.joern.x2cpg.frontendspecific.jssrc2cpg.Defines diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/parser/BabelAst.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/parser/BabelAst.scala index 2513121ab491..cae2a0ef5705 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/parser/BabelAst.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/parser/BabelAst.scala @@ -206,6 +206,7 @@ object BabelAst { object TSIndexSignature extends BabelNode object TSIndexedAccessType extends TSType object TSInferType extends TSType + object TSInstantiationExpression extends Expression object TSInterfaceBody extends BabelNode object TSInterfaceDeclaration extends BabelNode object TSIntersectionType extends TSType diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/parser/BabelJsonParser.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/parser/BabelJsonParser.scala index d03769e86548..5fe665af993a 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/parser/BabelJsonParser.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/parser/BabelJsonParser.scala @@ -4,7 +4,9 @@ import io.joern.x2cpg.astgen.BaseParserResult import io.shiftleft.utils.IOUtils import ujson.Value.Value -import java.nio.file.{Path, Paths} +import java.nio.file.Path +import java.nio.file.Paths +import scala.util.Try object BabelJsonParser { @@ -13,25 +15,41 @@ object BabelJsonParser { fullPath: String, json: Value, fileContent: String, - typeMap: Map[Int, String] + typeMap: Map[Int, String], + fileLoc: Int ) extends BaseParserResult - def readFile(rootPath: Path, file: Path): ParseResult = { - val typeMapPath = Paths.get(file.toString.replace(".json", ".typemap")) - val typeMap = if (typeMapPath.toFile.exists()) { + private def loadTypeMap(file: Path): Try[Map[Int, String]] = Try { + val typeMapPathString = file.toString.replaceAll("\\.[^.]*$", "") + ".typemap" + val typeMapPath = Paths.get(typeMapPathString) + if (typeMapPath.toFile.exists()) { val typeMapJsonContent = IOUtils.readEntireFile(typeMapPath) val typeMapJson = ujson.read(typeMapJsonContent) typeMapJson.obj.map { case (k, v) => k.toInt -> v.str }.toMap } else { - Map.empty[Int, String] + Map.empty } + } + + private def loadJson(file: Path): Try[Value] = Try { + val jsonContent = IOUtils.readEntireFile(file) + ujson.read(jsonContent) + } - val jsonContent = IOUtils.readEntireFile(file) - val json = ujson.read(jsonContent) + private def generateParserResult(rootPath: Path, json: Value, typeMap: Map[Int, String]): Try[ParseResult] = Try { val filename = json("relativeName").str val fullPath = Paths.get(rootPath.toString, filename) val sourceFileContent = IOUtils.readEntireFile(fullPath) - ParseResult(filename, fullPath.toString, json, sourceFileContent, typeMap) + val fileLoc = sourceFileContent.lines().count().toInt + ParseResult(filename, fullPath.toString, json, sourceFileContent, typeMap, fileLoc) + } + + def readFile(rootPath: Path, file: Path): Try[ParseResult] = { + val typeMap = loadTypeMap(file).getOrElse(Map.empty) + for { + json <- loadJson(file) + parseResult <- generateParserResult(rootPath, json, typeMap) + } yield parseResult } } diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/AstCreationPass.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/AstCreationPass.scala index f9664be8c0c4..a66ec0913efa 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/AstCreationPass.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/AstCreationPass.scala @@ -7,15 +7,19 @@ import io.joern.jssrc2cpg.utils.AstGenRunner.AstGenRunnerResult import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.datastructures.Global import io.joern.x2cpg.frontendspecific.jssrc2cpg.Defines -import io.joern.x2cpg.utils.{Report, TimeUtils} +import io.joern.x2cpg.utils.Report +import io.joern.x2cpg.utils.TimeUtils import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.passes.ForkJoinParallelCpgPass import io.shiftleft.utils.IOUtils -import org.slf4j.{Logger, LoggerFactory} +import org.slf4j.Logger +import org.slf4j.LoggerFactory import java.nio.file.Paths -import scala.util.{Failure, Success, Try} -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* +import scala.util.Failure +import scala.util.Success +import scala.util.Try class AstCreationPass(cpg: Cpg, astGenRunnerResult: AstGenRunnerResult, config: Config, report: Report = new Report())( implicit withSchemaValidation: ValidationMode @@ -45,20 +49,26 @@ class AstCreationPass(cpg: Cpg, astGenRunnerResult: AstGenRunnerResult, config: override def runOnPart(diffGraph: DiffGraphBuilder, input: (String, String)): Unit = { val (rootPath, jsonFilename) = input + val parseResultMaybe = BabelJsonParser.readFile(Paths.get(rootPath), Paths.get(jsonFilename)) val ((gotCpg, filename), duration) = TimeUtils.time { - val parseResult = BabelJsonParser.readFile(Paths.get(rootPath), Paths.get(jsonFilename)) - val fileLOC = IOUtils.readLinesInFile(Paths.get(parseResult.fullPath)).size - report.addReportInfo(parseResult.filename, fileLOC, parsed = true) - Try { - val localDiff = new AstCreator(config, global, parseResult).createAst() - diffGraph.absorb(localDiff) - } match { + parseResultMaybe match { + case Success(parseResult) => + report.addReportInfo(parseResult.filename, parseResult.fileLoc, parsed = true) + Try { + val localDiff = new AstCreator(config, global, parseResult).createAst() + diffGraph.absorb(localDiff) + } match { + case Failure(exception) => + logger.warn(s"Failed to generate a CPG for: '${parseResult.fullPath}'", exception) + (false, parseResult.filename) + case Success(_) => + logger.debug(s"Generated a CPG for: '${parseResult.fullPath}'") + (true, parseResult.filename) + } case Failure(exception) => - logger.warn(s"Failed to generate a CPG for: '${parseResult.fullPath}'", exception) - (false, parseResult.filename) - case Success(_) => - logger.debug(s"Generated a CPG for: '${parseResult.fullPath}'") - (true, parseResult.filename) + val pathOfFailedFile = jsonFilename.replaceAll("\\.[^.]*$", "") + logger.warn(s"Failed to read json parse result for: '$pathOfFailedFile'", exception) + (false, pathOfFailedFile) } } report.updateReport(filename, cpg = gotCpg, duration) diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/ImportsPass.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/ImportsPass.scala index a3aa62569e84..2ce72c78ef45 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/ImportsPass.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/ImportsPass.scala @@ -4,7 +4,7 @@ import io.joern.x2cpg.X2Cpg import io.joern.x2cpg.passes.frontend.XImportsPass import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Call -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment /** This pass creates `IMPORT` nodes by looking for calls to `require`. `IMPORT` nodes are linked to existing dependency diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeNodePass.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeNodePass.scala index b7eb8ea08bea..032b34655a4f 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeNodePass.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeNodePass.scala @@ -4,14 +4,13 @@ import io.joern.x2cpg.frontendspecific.jssrc2cpg.Defines import io.joern.x2cpg.passes.frontend.TypeNodePass import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.semanticcpg.language.* -import io.shiftleft.passes.KeyPool import scala.collection.mutable object JavaScriptTypeNodePass { - def withRegisteredTypes(registeredTypes: List[String], cpg: Cpg, keyPool: Option[KeyPool] = None): TypeNodePass = { - new TypeNodePass(registeredTypes, cpg, keyPool, getTypesFromCpg = false) { + def withRegisteredTypes(registeredTypes: List[String], cpg: Cpg): TypeNodePass = { + new TypeNodePass(registeredTypes, cpg, getTypesFromCpg = false) { override def fullToShortName(typeName: String): String = { typeName match { diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/utils/AstGenRunner.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/utils/AstGenRunner.scala index b99b9c2c6bf6..eab0508d0aad 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/utils/AstGenRunner.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/utils/AstGenRunner.scala @@ -1,12 +1,13 @@ package io.joern.jssrc2cpg.utils import better.files.File +import com.typesafe.config.ConfigFactory import io.joern.jssrc2cpg.Config import io.joern.jssrc2cpg.preprocessing.EjsPreprocessor import io.joern.x2cpg.SourceFiles -import io.joern.x2cpg.utils.{Environment, ExternalCommand} +import io.joern.x2cpg.utils.Environment +import io.joern.x2cpg.utils.ExternalCommand import io.shiftleft.utils.IOUtils -import com.typesafe.config.ConfigFactory import org.slf4j.LoggerFactory import versionsort.VersionHelper @@ -23,8 +24,6 @@ object AstGenRunner { private val LineLengthThreshold: Int = 10000 - private val NODE_OPTIONS: Map[String, String] = Map("NODE_OPTIONS" -> "--max-old-space-size=8192") - private val TypeDefinitionFileExtensions = List(".t.ts", ".d.ts") private val MinifiedPathRegex: Regex = ".*([.-]min\\..*js|bundle\\.js)".r @@ -126,7 +125,7 @@ object AstGenRunner { val astGenCommand = path.getOrElse("astgen") val localPath = path.flatMap(File(_).parentOption.map(_.pathAsString)).getOrElse(".") val debugMsgPath = path.getOrElse("PATH") - ExternalCommand.run(s"$astGenCommand --version", localPath).toOption.map(_.mkString.strip()) match { + ExternalCommand.run(Seq(astGenCommand, "--version"), localPath).successOption.map(_.mkString.strip()) match { case Some(installedVersion) if installedVersion != "unknown" && Try(VersionHelper.compare(installedVersion, astGenVersion)).toOption.getOrElse(-1) >= 0 => @@ -173,9 +172,22 @@ object AstGenRunner { class AstGenRunner(config: Config) { - import io.joern.jssrc2cpg.utils.AstGenRunner._ + import io.joern.jssrc2cpg.utils.AstGenRunner.* - private val executableArgs = if (!config.tsTypes) " --no-tsTypes" else "" + private val executableArgs = { + val tsArgs = if (!config.tsTypes) Seq("--no-tsTypes") else Seq.empty + val ignoredFilesRegex = if (config.ignoredFilesRegex.toString().nonEmpty) { + Seq("--exclude-regex", config.ignoredFilesRegex.toString()) + } else { + Seq.empty + } + val ignoreFileArgs = if (config.ignoredFiles.nonEmpty) { + Seq("--exclude-file") ++ config.ignoredFiles.map(f => s"\"$f\"") + } else { + Seq.empty + } + tsArgs ++ ignoredFilesRegex ++ ignoreFileArgs + } private def skippedFiles(astGenOut: List[String]): List[String] = { val skipped = astGenOut.collect { @@ -258,15 +270,23 @@ class AstGenRunner(config: Config) { private def filterFiles(files: List[String], out: File): List[String] = { files.filter { file => - file.stripSuffix(".json").replace(out.pathAsString, config.inputPath) match { - // We are not interested in JS / TS type definition files at this stage. - // TODO: maybe we can enable that later on and use the type definitions there - // for enhancing the CPG with additional type information for functions - case filePath if TypeDefinitionFileExtensions.exists(filePath.endsWith) => false - case filePath if isIgnoredByUserConfig(filePath) => false - case filePath if isIgnoredByDefault(filePath) => false - case filePath if isTranspiledFile(filePath) => false - case _ => true + Try { + file.stripSuffix(".json").replace(out.pathAsString, config.inputPath) match { + // We are not interested in JS / TS type definition files at this stage. + // TODO: maybe we can enable that later on and use the type definitions there + // for enhancing the CPG with additional type information for functions + case filePath if TypeDefinitionFileExtensions.exists(filePath.endsWith) => false + case filePath if isIgnoredByUserConfig(filePath) => false + case filePath if isIgnoredByDefault(filePath) => false + case filePath if isTranspiledFile(filePath) => false + case _ => true + } + } match { + case Success(result) => result + case Failure(exception) => + // Log the exception for debugging purposes + logger.error("An error occurred while processing the file path during filtering file stage : ", exception) + false } } } @@ -297,39 +317,61 @@ class AstGenRunner(config: Config) { } val result = - ExternalCommand.run(s"$astGenCommand$executableArgs -t ts -o $out", out.toString(), extraEnv = NODE_OPTIONS) + ExternalCommand.run((astGenCommand +: executableArgs) ++ Seq("-t", "ts", "-o", out.toString), out.toString()) val jsons = SourceFiles.determine(out.toString(), Set(".json")) jsons.foreach { jsonPath => - val jsonFile = File(jsonPath) - val jsonContent = IOUtils.readEntireFile(jsonFile.path) - val json = ujson.read(jsonContent) - val fileName = json("fullName").str - val newFileName = fileName.patch(fileName.lastIndexOf(".js"), ".ejs", 3) - json("relativeName") = newFileName - json("fullName") = newFileName + val jsonFile = File(jsonPath) + val jsonContent = IOUtils.readEntireFile(jsonFile.path) + val json = ujson.read(jsonContent) + val fullName = json("fullName").str + val relativeName = json("relativeName").str + val newFullName = fullName.patch(fullName.lastIndexOf(".js"), ".ejs", 3) + val newRelativeName = relativeName.patch(relativeName.lastIndexOf(".js"), ".ejs", 3) + json("relativeName") = newRelativeName + json("fullName") = newFullName jsonFile.writeText(json.toString()) } tmpJsFiles.foreach(_.delete()) - result + result.toTry } private def ejsFiles(in: File, out: File): Try[Seq[String]] = { - val files = SourceFiles.determine(in.pathAsString, Set(".ejs")) + val files = + SourceFiles.determine( + in.pathAsString, + Set(".ejs"), + ignoredDefaultRegex = Some(AstGenDefaultIgnoreRegex), + ignoredFilesRegex = Some(config.ignoredFilesRegex), + ignoredFilesPath = Some(config.ignoredFiles) + ) if (files.nonEmpty) processEjsFiles(in, out, files) else Success(Seq.empty) } private def vueFiles(in: File, out: File): Try[Seq[String]] = { - val files = SourceFiles.determine(in.pathAsString, Set(".vue")) - if (files.nonEmpty) - ExternalCommand.run(s"$astGenCommand$executableArgs -t vue -o $out", in.toString(), extraEnv = NODE_OPTIONS) - else Success(Seq.empty) + val files = SourceFiles.determine( + in.pathAsString, + Set(".vue"), + ignoredDefaultRegex = Some(AstGenDefaultIgnoreRegex), + ignoredFilesRegex = Some(config.ignoredFilesRegex), + ignoredFilesPath = Some(config.ignoredFiles) + ) + if (files.nonEmpty) { + ExternalCommand + .run((astGenCommand +: executableArgs) ++ Seq("-t", "vue", "-o", out.toString), in.toString()) + .toTry + } else { + Success(Seq.empty) + } } - private def jsFiles(in: File, out: File): Try[Seq[String]] = - ExternalCommand.run(s"$astGenCommand$executableArgs -t ts -o $out", in.toString(), extraEnv = NODE_OPTIONS) + private def jsFiles(in: File, out: File): Try[Seq[String]] = { + ExternalCommand + .run((astGenCommand +: executableArgs) ++ Seq("-t", "ts", "-o", out.toString), in.toString()) + .toTry + } private def runAstGenNative(in: File, out: File): Try[Seq[String]] = for { ejsResult <- ejsFiles(in, out) @@ -360,7 +402,9 @@ class AstGenRunner(config: Config) { AstGenRunnerResult(parsed.map((in.toString(), _)), skipped.map((in.toString(), _))) case Failure(f) => logger.error("\t- running astgen failed!", f) - AstGenRunnerResult() + val parsed = checkParsedFiles(filterFiles(SourceFiles.determine(out.toString(), Set(".json")), out), in) + val skipped = List.empty + AstGenRunnerResult(parsed.map((in.toString(), _)), skipped.map((in.toString(), _))) } } diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/config/ConfigTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/config/ConfigTests.scala index 1c5cb963a59e..671ccb26109f 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/config/ConfigTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/config/ConfigTests.scala @@ -22,7 +22,9 @@ class ConfigTests extends AnyWordSpec with Matchers with Inside { "--output", "OUTPUT", "--exclude", - "1EXCLUDE_FILE,2EXCLUDE_FILE", + "1EXCLUDE_FILE", + "--exclude", + "2EXCLUDE_FILE", "--exclude-regex", "EXCLUDE_REGEX", // Frontend-specific args diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/dataflow/DataflowTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/dataflow/DataflowTests.scala index 24fde502b8a4..6696b70e20de 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/dataflow/DataflowTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/dataflow/DataflowTests.scala @@ -1,11 +1,11 @@ package io.joern.jssrc2cpg.dataflow -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.jssrc2cpg.testfixtures.DataFlowCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.CfgNode -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DataflowTests extends DataFlowCodeToCpgSuite { @@ -461,7 +461,7 @@ class DataflowTests extends DataFlowCodeToCpgSuite { cpg.call .code("bar.*") .outE(EdgeTypes.REACHING_DEF) - .count(_.inNode() == cpg.ret.head) shouldBe 1 + .count(_.dst == cpg.ret.head) shouldBe 1 } "Flow from outer params to inner params" in { diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/io/CodeDumperFromContentTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/io/CodeDumperFromContentTests.scala index 6a5fc2274375..739b893f0b59 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/io/CodeDumperFromContentTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/io/CodeDumperFromContentTests.scala @@ -40,18 +40,22 @@ class CodeDumperFromContentTests extends JsSrc2CpgSuite { | var x = foo(param1); |}""".stripMargin - val cpg = code( - s""" + val fullCode = s""" |// A comment |$myFuncContent - |""".stripMargin, - "index.js" - ).withConfig(Config().withDisableFileContent(false)) + |""".stripMargin + + val cpg = code(fullCode, "index.js").withConfig(Config().withDisableFileContent(false)) "allow one to dump a method node's source code from `Method.content`" in { val List(content) = cpg.method.nameExact("my_func").content.l content shouldBe myFuncContent } + + "allow one to dump the :program method node's source code from `Method.content`" in { + val List(content) = cpg.method.nameExact(":program").content.l + content shouldBe fullCode + } } "code from typedecl content" should { @@ -73,6 +77,11 @@ class CodeDumperFromContentTests extends JsSrc2CpgSuite { val List(content) = cpg.typeDecl.nameExact("Foo").content.l content shouldBe myClassContent } + + "allow one to dump the method node's source code from `Method.content`" in { + val List(content) = cpg.method.nameExact("").content.l + content shouldBe myClassContent + } } "content with UTF8 characters" should { diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/io/CodeDumperFromFileTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/io/CodeDumperFromFileTests.scala index 9c27a157a302..ede0ff1342a9 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/io/CodeDumperFromFileTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/io/CodeDumperFromFileTests.scala @@ -3,7 +3,7 @@ package io.joern.jssrc2cpg.io import better.files.File import io.joern.jssrc2cpg.testfixtures.JsSrc2CpgSuite import io.shiftleft.semanticcpg.codedumper.CodeDumper -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.util.regex.Pattern diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/io/JsSrc2CpgHTTPServerTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/io/JsSrc2CpgHTTPServerTests.scala new file mode 100644 index 000000000000..02559d45de34 --- /dev/null +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/io/JsSrc2CpgHTTPServerTests.scala @@ -0,0 +1,83 @@ +package io.joern.jssrc2cpg.io + +import better.files.File +import io.joern.x2cpg.utils.server.FrontendHTTPClient +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader +import io.shiftleft.semanticcpg.language.* +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.util.Failure +import scala.util.Success +import scala.collection.parallel.CollectionConverters.RangeIsParallelizable + +class JsSrc2CpgHTTPServerTests extends AnyWordSpec with Matchers with BeforeAndAfterAll { + + private var port: Int = -1 + + private def newProjectUnderTest(index: Option[Int] = None): File = { + val dir = File.newTemporaryDirectory("jssrc2cpgTestsHttpTest") + val file = dir / "main.js" + file.createIfNotExists(createParents = true) + val indexStr = index.map(_.toString).getOrElse("") + file.writeText(s""" + |function main$indexStr() { + | console.log("Hello World!"); + |} + |""".stripMargin) + file.deleteOnExit() + dir.deleteOnExit() + } + + override def beforeAll(): Unit = { + // Start server + port = io.joern.jssrc2cpg.Main.startup() + } + + override def afterAll(): Unit = { + // Stop server + io.joern.jssrc2cpg.Main.stop() + } + + "Using jssrc2cpg in server mode" should { + "build CPGs correctly (single test)" in { + val cpgOutFile = File.newTemporaryFile("jssrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l should contain("""console.log("Hello World!")""") + } + } + + "build CPGs correctly (multi-threaded test)" in { + (0 until 10).par.foreach { index => + val cpgOutFile = File.newTemporaryFile("jssrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest(Some(index)) + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain(s"main$index") + cpg.call.code.l should contain("""console.log("Hello World!")""") + } + } + } + } + +} diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/CallLinkerPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/CallLinkerPassTests.scala index 0e31d1703470..1f6d34a4f8ef 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/CallLinkerPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/CallLinkerPassTests.scala @@ -3,7 +3,7 @@ package io.joern.jssrc2cpg.passes import io.joern.jssrc2cpg.testfixtures.DataFlowCodeToCpgSuite import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Identifier -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CallLinkerPassTests extends DataFlowCodeToCpgSuite { diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ConfigPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ConfigPassTests.scala index 87a0e66dad5e..3d2f165406e6 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ConfigPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ConfigPassTests.scala @@ -4,7 +4,7 @@ import better.files.File import io.joern.jssrc2cpg.Config import io.joern.x2cpg.frontendspecific.jssrc2cpg.Defines import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ConstClosurePassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ConstClosurePassTests.scala index 6abed767948d..f1acaba369fa 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ConstClosurePassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ConstClosurePassTests.scala @@ -1,7 +1,7 @@ package io.joern.jssrc2cpg.passes import io.joern.jssrc2cpg.testfixtures.DataFlowCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ConstClosurePassTests extends DataFlowCodeToCpgSuite { diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/DomPassTestsHelper.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/DomPassTestsHelper.scala index dc9ac7cc858c..98433e4d23ef 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/DomPassTestsHelper.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/DomPassTestsHelper.scala @@ -3,7 +3,7 @@ package io.joern.jssrc2cpg.passes import io.shiftleft.codepropertygraph.generated.nodes.Expression import io.shiftleft.codepropertygraph.generated.nodes.TemplateDom import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.apache.commons.lang3.StringUtils trait DomPassTestsHelper { diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/InheritanceFullNamePassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/InheritanceFullNamePassTests.scala index cfccc1632b34..f580bf8cb688 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/InheritanceFullNamePassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/InheritanceFullNamePassTests.scala @@ -1,7 +1,7 @@ package io.joern.jssrc2cpg.passes import io.joern.jssrc2cpg.testfixtures.DataFlowCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.io.File import scala.annotation.nowarn diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/JsMetaDataPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/JsMetaDataPassTests.scala index 399723cdc93c..bba23114b188 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/JsMetaDataPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/JsMetaDataPassTests.scala @@ -18,11 +18,11 @@ class JsMetaDataPassTests extends AnyWordSpec with Matchers with Inside { new JavaScriptMetaDataPass(cpg, "somehash", "").createAndApply() "create exactly 1 node" in { - cpg.graph.V.asScala.size shouldBe 1 + cpg.graph.allNodes.size shouldBe 1 } "create no edges" in { - cpg.graph.E.asScala.size shouldBe 0 + cpg.graph.allNodes.outE.size shouldBe 0 } "create a metadata node with correct language" in { diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/RequirePassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/RequirePassTests.scala index 672a73ca0f65..b354a9942b94 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/RequirePassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/RequirePassTests.scala @@ -1,8 +1,8 @@ package io.joern.jssrc2cpg.passes -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.language.* import io.joern.jssrc2cpg.testfixtures.DataFlowCodeToCpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.io.File diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/DependencyAstCreationPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/DependencyAstCreationPassTests.scala index daa810ae9689..1f265204c966 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/DependencyAstCreationPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/DependencyAstCreationPassTests.scala @@ -3,7 +3,7 @@ package io.joern.jssrc2cpg.passes.ast import io.joern.jssrc2cpg.testfixtures.AstJsSrc2CpgSuite import io.joern.x2cpg.layers.Base import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DependencyAstCreationPassTests extends AstJsSrc2CpgSuite { @@ -49,6 +49,12 @@ class DependencyAstCreationPassTests extends AstJsSrc2CpgSuite { } "AST generation for dependencies" should { + "reference identifiers correctly" in { + val cpg = code("export const foo = bar();") + cpg.local.name("foo").referencingIdentifiers.size shouldBe 2 + cpg.identifier.name("foo").size shouldBe 2 + } + "have no dependencies if none are declared at all" in { val cpg = code("var x = 1;") cpg.dependency.l.size shouldBe 0 @@ -335,7 +341,7 @@ class DependencyAstCreationPassTests extends AstJsSrc2CpgSuite { |export = function () {}; // anonymous |export = class ClassA {}; |""".stripMargin) - cpg.local.code.l shouldBe List("foo", "bar", "func", "0") + cpg.local.code.l shouldBe List("foo", "bar", "func", "0", "ClassA") cpg.typeDecl.name.l should contain allElementsOf List("func", "ClassA") cpg.assignment.code.l shouldBe List( "var foo = 1", diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/MixedAstCreationPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/MixedAstCreationPassTests.scala index be5e156a9e4b..7653f019c727 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/MixedAstCreationPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/MixedAstCreationPassTests.scala @@ -4,8 +4,8 @@ import io.joern.jssrc2cpg.testfixtures.AstJsSrc2CpgSuite import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.codepropertygraph.generated.EvaluationStrategies import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.codepropertygraph.generated.nodes.MethodParameterIn -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.{ClosureBinding, MethodParameterIn} +import io.shiftleft.semanticcpg.language.* class MixedAstCreationPassTests extends AstJsSrc2CpgSuite { @@ -285,7 +285,7 @@ class MixedAstCreationPassTests extends AstJsSrc2CpgSuite { val List(fooLocalY) = fooBlock.astChildren.isLocal.nameExact("y").l val List(barRef) = fooBlock.astChildren.isCall.astChildren.isMethodRef.l - val List(closureBindForY, closureBindForX) = barRef.captureOut.l + val List(closureBindForY, closureBindForX) = barRef.captureOut.cast[ClosureBinding].l closureBindForX.closureOriginalName shouldBe Option("x") closureBindForX.closureBindingId shouldBe Option("Test0.js::program:foo:bar:x") diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/SimpleAstCreationPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/SimpleAstCreationPassTests.scala index f1bab7a456f8..e642bd021531 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/SimpleAstCreationPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/SimpleAstCreationPassTests.scala @@ -269,31 +269,54 @@ class SimpleAstCreationPassTests extends AstJsSrc2CpgSuite { argument3.code shouldBe "\"\"" } - "have correct structure for tagged runtime node" in { - val cpg = code(s"String.raw`../$${42}\\..`") + "have correct structure for tagged runtime node with simple tag expression" in { + val cpg = code(s"x`a $${1+1} b`") val List(method) = cpg.method.nameExact(":program").l val List(methodBlock) = method.astChildren.isBlock.l val List(rawCall) = methodBlock.astChildren.isCall.l - rawCall.code shouldBe s"String.raw(${Operators.formatString}(\"../\", 42, \"\\..\"))" - - val List(runtimeCall) = rawCall.astChildren.isCall.nameExact(Operators.formatString).l - runtimeCall.order shouldBe 1 - runtimeCall.argumentIndex shouldBe 1 - runtimeCall.code shouldBe s"${Operators.formatString}(\"../\", 42, \"\\..\")" + rawCall.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + rawCall.name shouldBe "x" + rawCall.receiver.isIdentifier.code.l shouldBe List("x") + rawCall.code shouldBe s"x`a $${1+1} b`" - val List(argument1) = runtimeCall.astChildren.isLiteral.codeExact("\"../\"").l - argument1.order shouldBe 1 + val List(argument1) = rawCall.astChildren.isBlock.l argument1.argumentIndex shouldBe 1 + argument1.astChildren.code.l shouldBe List( + "_tmp_0", + "_tmp_0 = __ecma.Array.factory()", + "_tmp_0.push(\"a \")", + "_tmp_0.push(\" b\")", + "_tmp_0" + ) - val List(argument2) = runtimeCall.astChildren.isLiteral.codeExact("42").l - argument2.order shouldBe 2 + val List(argument2) = rawCall.astChildren.isCall.codeExact("1+1").l argument2.argumentIndex shouldBe 2 + } - val List(argument3) = - runtimeCall.astChildren.isLiteral.codeExact("\"\\..\"").l - argument3.order shouldBe 3 - argument3.argumentIndex shouldBe 3 + "have correct structure for tagged runtime node with complex tag expression" in { + val cpg = code(s"String.raw`../$${42}\\..`") + val List(method) = cpg.method.nameExact(":program").l + val List(methodBlock) = method.astChildren.isBlock.l + + val List(rawCall) = methodBlock.astChildren.isCall.l + rawCall.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + rawCall.name shouldBe "raw" + rawCall.receiver.isCall.code.l shouldBe List("String.raw") + rawCall.code shouldBe "String.raw`../${42}\\..`" + + val List(argument1) = rawCall.astChildren.isBlock.l + argument1.argumentIndex shouldBe 1 + argument1.astChildren.code.l shouldBe List( + "_tmp_0", + "_tmp_0 = __ecma.Array.factory()", + "_tmp_0.push(\"../\")", + "_tmp_0.push(\"\\..\")", + "_tmp_0" + ) + + val List(argument2) = rawCall.astChildren.isLiteral.codeExact("42").l + argument2.argumentIndex shouldBe 2 } "have correct structure for different string literals" in { @@ -859,7 +882,7 @@ class SimpleAstCreationPassTests extends AstJsSrc2CpgSuite { val List(typeDecl) = cpg.typeDecl.nameExact("method").l typeDecl.fullName should endWith("Test0.js::program:method") - val List(binding) = typeDecl.bindsOut.l + val List(binding) = typeDecl.bindsOut.cast[Binding].l binding.name shouldBe "" binding.signature shouldBe "" diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsAstCreationPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsAstCreationPassTests.scala index 71144ce1141b..33c2d5209640 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsAstCreationPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsAstCreationPassTests.scala @@ -108,8 +108,7 @@ class TsAstCreationPassTests extends AstJsSrc2CpgSuite(".ts") { arg.typeFullName shouldBe Defines.String arg.code shouldBe "arg: string" arg.index shouldBe 1 - val List(parentTypeDecl) = cpg.typeDecl.name(":program").l - parentTypeDecl.bindsOut.flatMap(_.refOut).l should contain(func) + cpg.method("foo").bindingTypeDecl.fullName.l shouldBe List("Test0.ts::program:foo") } "have correct structure for type assertion" in { diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsClassesAstCreationPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsClassesAstCreationPassTests.scala index fed85c602569..43d24af28886 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsClassesAstCreationPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsClassesAstCreationPassTests.scala @@ -191,7 +191,10 @@ class TsClassesAstCreationPassTests extends AstJsSrc2CpgSuite(".ts") { greeter.fullName shouldBe "Test0.ts::program:Greeter" greeter.filename shouldBe "Test0.ts" greeter.file.name.head shouldBe "Test0.ts" - inside(cpg.typeDecl("Greeter").member.l) { case List(greeting, name, propName, foo, anon, toString) => + inside(cpg.typeDecl("Greeter").member.l) { case List(init, greeting, name, propName, foo, anon, toString) => + init.name shouldBe "" + init.typeFullName shouldBe "Test0.ts::program:Greeter" + init.dynamicTypeHintFullName shouldBe List("Test0.ts::program:Greeter:") greeting.name shouldBe "greeting" greeting.code shouldBe "greeting: string;" name.name shouldBe "name" @@ -339,7 +342,7 @@ class TsClassesAstCreationPassTests extends AstJsSrc2CpgSuite(".ts") { val List(credentialsParam) = cpg.parameter.nameExact("credentials").l credentialsParam.typeFullName shouldBe "Test0.ts::program:Test:run:0" // should not produce dangling nodes that are meant to be inside procedures - cpg.all.collectAll[CfgNode].whereNot(_._astIn).size shouldBe 0 + cpg.all.collectAll[CfgNode].whereNot(_.astParent).size shouldBe 0 cpg.identifier.count(_.refsTo.size > 1) shouldBe 0 cpg.identifier.whereNot(_.refsTo).size shouldBe 0 // should not produce assignment calls directly under typedecls @@ -359,7 +362,7 @@ class TsClassesAstCreationPassTests extends AstJsSrc2CpgSuite(".ts") { val List(credentialsParam) = cpg.parameter.nameExact("param1_0").l credentialsParam.typeFullName shouldBe "Test0.ts::program:apiCall:0" // should not produce dangling nodes that are meant to be inside procedures - cpg.all.collectAll[CfgNode].whereNot(_._astIn).size shouldBe 0 + cpg.all.collectAll[CfgNode].whereNot(_.astParent).size shouldBe 0 cpg.identifier.count(_.refsTo.size > 1) shouldBe 0 cpg.identifier.whereNot(_.refsTo).size shouldBe 0 // should not produce assignment calls directly under typedecls diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsDecoratorAstCreationPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsDecoratorAstCreationPassTests.scala index c15f2dc879ae..38bd7b036a5c 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsDecoratorAstCreationPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/ast/TsDecoratorAstCreationPassTests.scala @@ -1,8 +1,7 @@ package io.joern.jssrc2cpg.passes.ast import io.joern.jssrc2cpg.testfixtures.AstJsSrc2CpgSuite -import io.joern.x2cpg.frontendspecific.jssrc2cpg.Defines -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class TsDecoratorAstCreationPassTests extends AstJsSrc2CpgSuite(".ts") { @@ -324,308 +323,6 @@ class TsDecoratorAstCreationPassTests extends AstJsSrc2CpgSuite(".ts") { annotationD.parameterAssign.l shouldBe empty } } - - "create methods for const exports" in { - val cpg = code("export const getApiA = (req: Request) => { const user = req.user as UserDocument; }") - cpg.method.name.sorted.l shouldBe List(":program", "0") - cpg.assignment.code.l shouldBe List( - "const user = req.user as UserDocument", - "const getApiA = (req: Request) => { const user = req.user as UserDocument; }", - "exports.getApiA = getApiA" - ) - inside(cpg.method.name("0").l) { case List(anon) => - anon.fullName shouldBe "Test0.ts::program:0" - anon.ast.isIdentifier.name.l shouldBe List("user", "req") - } - } - - "have correct structure for import assignments" in { - val cpg = code(""" - |import fs = require('fs'); - |import models = require('../models/index'); - |""".stripMargin) - cpg.assignment.code.l shouldBe List("var fs = require(\"fs\")", "var models = require(\"../models/index\")") - cpg.local.code.l shouldBe List("fs", "models") - val List(fsDep, modelsDep) = cpg.dependency.l - fsDep.name shouldBe "fs" - fsDep.dependencyGroupId shouldBe Option("fs") - modelsDep.name shouldBe "models" - modelsDep.dependencyGroupId shouldBe Option("../models/index") - - val List(fs, models) = cpg.imports.l - fs.code shouldBe "import fs = require('fs')" - fs.importedEntity shouldBe Option("fs") - fs.importedAs shouldBe Option("fs") - models.code shouldBe "import models = require('../models/index')" - models.importedEntity shouldBe Option("../models/index") - models.importedAs shouldBe Option("models") - } - - "have correct structure for declared functions" in { - val cpg = code("declare function foo(arg: string): string") - val List(func) = cpg.method("foo").l - func.code shouldBe "declare function foo(arg: string): string" - func.name shouldBe "foo" - func.fullName shouldBe "Test0.ts::program:foo" - val List(_, arg) = cpg.method("foo").parameter.l - arg.name shouldBe "arg" - arg.typeFullName shouldBe Defines.String - arg.code shouldBe "arg: string" - arg.index shouldBe 1 - val List(parentTypeDecl) = cpg.typeDecl.name(":program").l - parentTypeDecl.bindsOut.flatMap(_.refOut).l should contain(func) - } - - } - - "AST generation for TS enums" should { - - "have correct structure for simple enum" in { - val cpg = code(""" - |enum Direction { - | Up = 1, - | Down, - | Left, - | Right, - |} - |""".stripMargin) - inside(cpg.typeDecl("Direction").l) { case List(direction) => - direction.name shouldBe "Direction" - direction.code shouldBe "enum Direction" - direction.fullName shouldBe "Test0.ts::program:Direction" - direction.filename shouldBe "Test0.ts" - direction.file.name.head shouldBe "Test0.ts" - inside(direction.method.name(io.joern.x2cpg.Defines.StaticInitMethodName).l) { case List(init) => - init.block.astChildren.isCall.code.head shouldBe "Up = 1" - } - inside(cpg.typeDecl("Direction").member.l) { case List(up, down, left, right) => - up.name shouldBe "Up" - up.code shouldBe "Up = 1" - down.name shouldBe "Down" - down.code shouldBe "Down" - left.name shouldBe "Left" - left.code shouldBe "Left" - right.name shouldBe "Right" - right.code shouldBe "Right" - } - } - } - - } - - "AST generation for TS classes" should { - - "have correct structure for simple classes" in { - val cpg = code(""" - |class Greeter { - | greeting: string; - | greet() { - | return "Hello, " + this.greeting; - | } - |} - |""".stripMargin) - inside(cpg.typeDecl("Greeter").l) { case List(greeter) => - greeter.name shouldBe "Greeter" - greeter.code shouldBe "class Greeter" - greeter.fullName shouldBe "Test0.ts::program:Greeter" - greeter.filename shouldBe "Test0.ts" - greeter.file.name.head shouldBe "Test0.ts" - val constructor = greeter.method.nameExact(io.joern.x2cpg.Defines.ConstructorMethodName).head - greeter.method.isConstructor.head shouldBe constructor - constructor.fullName shouldBe s"Test0.ts::program:Greeter:${io.joern.x2cpg.Defines.ConstructorMethodName}" - inside(cpg.typeDecl("Greeter").member.l) { case List(greeting, greet) => - greeting.name shouldBe "greeting" - greeting.code shouldBe "greeting: string;" - greet.name shouldBe "greet" - greet.dynamicTypeHintFullName shouldBe Seq("Test0.ts::program:Greeter:greet") - } - } - } - - "have correct structure for declared classes with empty constructor" in { - val cpg = code(""" - |declare class Greeter { - | greeting: string; - | constructor(arg: string); - |} - |""".stripMargin) - inside(cpg.typeDecl("Greeter").l) { case List(greeter) => - greeter.name shouldBe "Greeter" - greeter.code shouldBe "class Greeter" - greeter.fullName shouldBe "Test0.ts::program:Greeter" - greeter.filename shouldBe "Test0.ts" - greeter.file.name.head shouldBe "Test0.ts" - val constructor = greeter.method.nameExact(io.joern.x2cpg.Defines.ConstructorMethodName).head - constructor.fullName shouldBe s"Test0.ts::program:Greeter:${io.joern.x2cpg.Defines.ConstructorMethodName}" - greeter.method.isConstructor.head shouldBe constructor - inside(cpg.typeDecl("Greeter").member.l) { case List(greeting) => - greeting.name shouldBe "greeting" - greeting.code shouldBe "greeting: string;" - } - } - } - - "have correct modifier" in { - val cpg = code(""" - |abstract class Greeter { - | static a: string; - | private b: string; - | public c: string; - | protected d: string; - | #e: string; // also private - |} - |""".stripMargin) - inside(cpg.typeDecl.name("Greeter.*").l) { case List(greeter) => - greeter.name shouldBe "Greeter" - cpg.typeDecl.isAbstract.head shouldBe greeter - greeter.member.isStatic.head shouldBe greeter.member.name("a").head - greeter.member.isPrivate.l shouldBe greeter.member.name("b", "e").l - greeter.member.isPublic.head shouldBe greeter.member.name("c").head - greeter.member.isProtected.head shouldBe greeter.member.name("d").head - } - } - - "have correct structure for empty interfaces" in { - val cpg = code(""" - |interface A {}; - |interface B {}; - |""".stripMargin) - cpg.method.fullName.sorted.l shouldBe List( - "Test0.ts::program", - s"Test0.ts::program:A:${io.joern.x2cpg.Defines.ConstructorMethodName}", - s"Test0.ts::program:B:${io.joern.x2cpg.Defines.ConstructorMethodName}" - ) - } - - "have correct structure for simple interfaces" in { - val cpg = code(""" - |interface Greeter { - | greeting: string; - | name?: string; - | [propName: string]: any; - | "foo": string; - | (source: string, subString: string): boolean; - |} - |""".stripMargin) - inside(cpg.typeDecl("Greeter").l) { case List(greeter) => - greeter.name shouldBe "Greeter" - greeter.code shouldBe "interface Greeter" - greeter.fullName shouldBe "Test0.ts::program:Greeter" - greeter.filename shouldBe "Test0.ts" - greeter.file.name.head shouldBe "Test0.ts" - inside(cpg.typeDecl("Greeter").member.l) { case List(greeting, name, propName, foo, anon) => - greeting.name shouldBe "greeting" - greeting.code shouldBe "greeting: string;" - name.name shouldBe "name" - name.code shouldBe "name?: string;" - propName.name shouldBe "propName" - propName.code shouldBe "[propName: string]: any;" - foo.name shouldBe "foo" - foo.code shouldBe "\"foo\": string;" - anon.name shouldBe "0" - anon.dynamicTypeHintFullName shouldBe Seq("Test0.ts::program:Greeter:0") - anon.code shouldBe "(source: string, subString: string): boolean;" - } - inside(cpg.typeDecl("Greeter").method.l) { case List(constructor, anon) => - constructor.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName - constructor.fullName shouldBe s"Test0.ts::program:Greeter:${io.joern.x2cpg.Defines.ConstructorMethodName}" - constructor.code shouldBe "new: Greeter" - greeter.method.isConstructor.head shouldBe constructor - anon.name shouldBe "0" - anon.fullName shouldBe "Test0.ts::program:Greeter:0" - anon.code shouldBe "(source: string, subString: string): boolean;" - anon.parameter.name.l shouldBe List("this", "source", "subString") - anon.parameter.code.l shouldBe List("this", "source: string", "subString: string") - } - } - } - - "have correct structure for interface constructor" in { - val cpg = code(""" - |interface Greeter { - | new (param: string) : Greeter - |} - |""".stripMargin) - inside(cpg.typeDecl("Greeter").l) { case List(greeter) => - greeter.name shouldBe "Greeter" - greeter.code shouldBe "interface Greeter" - greeter.fullName shouldBe "Test0.ts::program:Greeter" - greeter.filename shouldBe "Test0.ts" - greeter.file.name.head shouldBe "Test0.ts" - inside(cpg.typeDecl("Greeter").method.l) { case List(constructor) => - constructor.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName - constructor.fullName shouldBe s"Test0.ts::program:Greeter:${io.joern.x2cpg.Defines.ConstructorMethodName}" - constructor.code shouldBe "new (param: string) : Greeter" - constructor.parameter.name.l shouldBe List("this", "param") - constructor.parameter.code.l shouldBe List("this", "param: string") - greeter.method.isConstructor.head shouldBe constructor - } - } - } - - "have correct structure for simple namespace" in { - val cpg = code(""" - |namespace A { - | class Foo {}; - |} - |""".stripMargin) - inside(cpg.namespaceBlock("A").l) { case List(namespaceA) => - namespaceA.code should startWith("namespace A") - namespaceA.fullName shouldBe "Test0.ts::program:A" - namespaceA.typeDecl.name("Foo").head.fullName shouldBe "Test0.ts::program:A:Foo" - } - } - - "have correct structure for nested namespaces" in { - val cpg = code(""" - |namespace A { - | namespace B { - | namespace C { - | class Foo {}; - | } - | } - |} - |""".stripMargin) - inside(cpg.namespaceBlock("A").l) { case List(namespaceA) => - namespaceA.code should startWith("namespace A") - namespaceA.fullName shouldBe "Test0.ts::program:A" - namespaceA.astChildren.astChildren.isNamespaceBlock.name("B").head shouldBe cpg.namespaceBlock("B").head - } - inside(cpg.namespaceBlock("B").l) { case List(namespaceB) => - namespaceB.code should startWith("namespace B") - namespaceB.fullName shouldBe "Test0.ts::program:A:B" - namespaceB.astChildren.astChildren.isNamespaceBlock.name("C").head shouldBe cpg.namespaceBlock("C").head - } - inside(cpg.namespaceBlock("C").l) { case List(namespaceC) => - namespaceC.code should startWith("namespace C") - namespaceC.fullName shouldBe "Test0.ts::program:A:B:C" - namespaceC.typeDecl.name("Foo").head.fullName shouldBe "Test0.ts::program:A:B:C:Foo" - } - } - - "have correct structure for nested namespaces with path" in { - val cpg = code(""" - |namespace A.B.C { - | class Foo {}; - |} - |""".stripMargin) - inside(cpg.namespaceBlock("A").l) { case List(namespaceA) => - namespaceA.code should startWith("namespace A") - namespaceA.fullName shouldBe "Test0.ts::program:A" - namespaceA.astChildren.isNamespaceBlock.name("B").head shouldBe cpg.namespaceBlock("B").head - } - inside(cpg.namespaceBlock("B").l) { case List(b) => - b.code should startWith("B.C") - b.fullName shouldBe "Test0.ts::program:A:B" - b.astChildren.isNamespaceBlock.name("C").head shouldBe cpg.namespaceBlock("C").head - } - inside(cpg.namespaceBlock("C").l) { case List(c) => - c.code should startWith("C") - c.fullName shouldBe "Test0.ts::program:A:B:C" - c.typeDecl.name("Foo").head.fullName shouldBe "Test0.ts::program:A:B:C:Foo" - } - } - } } diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/DependencyCfgCreationPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/DependencyCfgCreationPassTests.scala index c83d24acb0b4..7e914d662a6f 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/DependencyCfgCreationPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/DependencyCfgCreationPassTests.scala @@ -10,16 +10,16 @@ class DependencyCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTe "CFG generation for global builtins" should { "be correct for JSON.parse" in { implicit val cpg: Cpg = code("""JSON.parse("foo");""") - succOf(":program") shouldBe expected((""""foo"""", AlwaysEdge)) - succOf(""""foo"""") shouldBe expected(("""JSON.parse("foo")""", AlwaysEdge)) - succOf("""JSON.parse("foo")""") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected((""""foo"""", AlwaysEdge)) + succOf(""""foo"""") should contain theSameElementsAs expected(("""JSON.parse("foo")""", AlwaysEdge)) + succOf("""JSON.parse("foo")""") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "have correct structure for JSON.stringify" in { implicit val cpg: Cpg = code("""JSON.stringify(foo);""") - succOf(":program") shouldBe expected(("foo", AlwaysEdge)) - succOf("foo") shouldBe expected(("JSON.stringify(foo)", AlwaysEdge)) - succOf("JSON.stringify(foo)") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("foo", AlwaysEdge)) + succOf("foo") should contain theSameElementsAs expected(("JSON.stringify(foo)", AlwaysEdge)) + succOf("JSON.stringify(foo)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/JsClassesCfgCreationPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/JsClassesCfgCreationPassTests.scala index 8301cf8a4afc..f827e03dbaab 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/JsClassesCfgCreationPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/JsClassesCfgCreationPassTests.scala @@ -11,61 +11,65 @@ class JsClassesCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTes "CFG generation for constructor" should { "be correct for simple new" in { implicit val cpg: Cpg = code("new MyClass()") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected((".alloc", AlwaysEdge)) - succOf(".alloc") shouldBe expected(("_tmp_0 = .alloc", AlwaysEdge)) - succOf("_tmp_0 = .alloc") shouldBe expected(("MyClass", AlwaysEdge)) - succOf("MyClass") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("new MyClass()", AlwaysEdge)) - succOf("new MyClass()", NodeTypes.CALL) shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("new MyClass()", AlwaysEdge)) - succOf("new MyClass()") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected((".alloc", AlwaysEdge)) + succOf(".alloc") should contain theSameElementsAs expected(("_tmp_0 = .alloc", AlwaysEdge)) + succOf("_tmp_0 = .alloc") should contain theSameElementsAs expected(("MyClass", AlwaysEdge)) + succOf("MyClass") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("new MyClass()", AlwaysEdge)) + succOf("new MyClass()", NodeTypes.CALL) should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("new MyClass()", AlwaysEdge)) + succOf("new MyClass()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for simple new with arguments" in { implicit val cpg: Cpg = code("new MyClass(arg1, arg2)") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected((".alloc", AlwaysEdge)) - succOf(".alloc") shouldBe expected(("_tmp_0 = .alloc", AlwaysEdge)) - succOf("_tmp_0 = .alloc") shouldBe expected(("MyClass", AlwaysEdge)) - succOf("MyClass") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("arg1", AlwaysEdge)) - succOf("arg1") shouldBe expected(("arg2", AlwaysEdge)) - succOf("arg2") shouldBe expected(("new MyClass(arg1, arg2)", AlwaysEdge)) - succOf("new MyClass(arg1, arg2)", NodeTypes.CALL) shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("new MyClass(arg1, arg2)", AlwaysEdge)) - succOf("new MyClass(arg1, arg2)") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected((".alloc", AlwaysEdge)) + succOf(".alloc") should contain theSameElementsAs expected(("_tmp_0 = .alloc", AlwaysEdge)) + succOf("_tmp_0 = .alloc") should contain theSameElementsAs expected(("MyClass", AlwaysEdge)) + succOf("MyClass") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("arg1", AlwaysEdge)) + succOf("arg1") should contain theSameElementsAs expected(("arg2", AlwaysEdge)) + succOf("arg2") should contain theSameElementsAs expected(("new MyClass(arg1, arg2)", AlwaysEdge)) + succOf("new MyClass(arg1, arg2)", NodeTypes.CALL) should contain theSameElementsAs expected( + ("_tmp_0", 2, AlwaysEdge) + ) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("new MyClass(arg1, arg2)", AlwaysEdge)) + succOf("new MyClass(arg1, arg2)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for new with access path" in { implicit val cpg: Cpg = code("new foo.bar.MyClass()") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected((".alloc", AlwaysEdge)) - succOf(".alloc") shouldBe expected(("_tmp_0 = .alloc", AlwaysEdge)) - succOf("_tmp_0 = .alloc") shouldBe expected(("foo", AlwaysEdge)) - succOf("foo") shouldBe expected(("bar", AlwaysEdge)) - succOf("bar") shouldBe expected(("foo.bar", AlwaysEdge)) - succOf("foo.bar") shouldBe expected(("MyClass", AlwaysEdge)) - succOf("MyClass") shouldBe expected(("foo.bar.MyClass", AlwaysEdge)) - succOf("foo.bar.MyClass") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("new foo.bar.MyClass()", AlwaysEdge)) - succOf("new foo.bar.MyClass()", NodeTypes.CALL) shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("new foo.bar.MyClass()", AlwaysEdge)) - succOf("new foo.bar.MyClass()") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected((".alloc", AlwaysEdge)) + succOf(".alloc") should contain theSameElementsAs expected(("_tmp_0 = .alloc", AlwaysEdge)) + succOf("_tmp_0 = .alloc") should contain theSameElementsAs expected(("foo", AlwaysEdge)) + succOf("foo") should contain theSameElementsAs expected(("bar", AlwaysEdge)) + succOf("bar") should contain theSameElementsAs expected(("foo.bar", AlwaysEdge)) + succOf("foo.bar") should contain theSameElementsAs expected(("MyClass", AlwaysEdge)) + succOf("MyClass") should contain theSameElementsAs expected(("foo.bar.MyClass", AlwaysEdge)) + succOf("foo.bar.MyClass") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("new foo.bar.MyClass()", AlwaysEdge)) + succOf("new foo.bar.MyClass()", NodeTypes.CALL) should contain theSameElementsAs expected( + ("_tmp_0", 2, AlwaysEdge) + ) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("new foo.bar.MyClass()", AlwaysEdge)) + succOf("new foo.bar.MyClass()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be structure for throw new exceptions" in { implicit val cpg: Cpg = code("function foo() { throw new Foo() }") - succOf("foo", NodeTypes.METHOD) shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected((".alloc", AlwaysEdge)) - succOf(".alloc") shouldBe expected(("_tmp_0 = .alloc", AlwaysEdge)) - succOf("_tmp_0 = .alloc") shouldBe expected(("Foo", AlwaysEdge)) - succOf("Foo") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("new Foo()", AlwaysEdge)) - succOf("new Foo()", NodeTypes.CALL) shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("new Foo()", AlwaysEdge)) - succOf("new Foo()") shouldBe expected(("throw new Foo()", AlwaysEdge)) - succOf("throw new Foo()") shouldBe expected(("RET", AlwaysEdge)) + succOf("foo", NodeTypes.METHOD) should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected((".alloc", AlwaysEdge)) + succOf(".alloc") should contain theSameElementsAs expected(("_tmp_0 = .alloc", AlwaysEdge)) + succOf("_tmp_0 = .alloc") should contain theSameElementsAs expected(("Foo", AlwaysEdge)) + succOf("Foo") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("new Foo()", AlwaysEdge)) + succOf("new Foo()", NodeTypes.CALL) should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("new Foo()", AlwaysEdge)) + succOf("new Foo()") should contain theSameElementsAs expected(("throw new Foo()", AlwaysEdge)) + succOf("throw new Foo()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } @@ -78,10 +82,10 @@ class JsClassesCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTes | } |} |""".stripMargin) - succOf("foo", NodeTypes.METHOD) shouldBe expected(("bar", AlwaysEdge)) - succOf("bar") shouldBe expected(("this", AlwaysEdge)) - succOf("this", NodeTypes.IDENTIFIER) shouldBe expected(("bar()", AlwaysEdge)) - succOf("bar()") shouldBe expected(("RET", 2, AlwaysEdge)) + succOf("foo", NodeTypes.METHOD) should contain theSameElementsAs expected(("bar", AlwaysEdge)) + succOf("bar") should contain theSameElementsAs expected(("this", AlwaysEdge)) + succOf("this", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("bar()", AlwaysEdge)) + succOf("bar()") should contain theSameElementsAs expected(("RET", 2, AlwaysEdge)) } "be correct for methods in class type decls with assignment" in { @@ -92,17 +96,17 @@ class JsClassesCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTes | } |} |""".stripMargin) - succOf(":program") shouldBe expected(("a", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("a", AlwaysEdge)) // call to constructor of ClassA - succOf("a") shouldBe expected(("class ClassA", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("class ClassA", AlwaysEdge)) } "be correct for outer method of anonymous class declaration" in { implicit val cpg: Cpg = code("var a = class {}") - succOf(":program") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("class 0", AlwaysEdge)) - succOf("class 0") shouldBe expected(("var a = class {}", AlwaysEdge)) - succOf("var a = class {}") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("class 0", AlwaysEdge)) + succOf("class 0") should contain theSameElementsAs expected(("var a = class {}", AlwaysEdge)) + succOf("var a = class {}") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/MixedCfgCreationPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/MixedCfgCreationPassTests.scala index bef8d7d9f81c..f3c619eed31c 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/MixedCfgCreationPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/MixedCfgCreationPassTests.scala @@ -7,164 +7,172 @@ import io.joern.x2cpg.passes.controlflow.cfgcreation.Cfg.TrueEdge import io.joern.x2cpg.testfixtures.CfgTestFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.NodeTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MixedCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCpg()) { "CFG generation for destructing assignment" should { "be correct for object destruction assignment with declaration" in { implicit val cpg: Cpg = code("var {a, b} = x") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0 = x", AlwaysEdge)) - - succOf("_tmp_0 = x") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("a", 1, AlwaysEdge)) - succOf("a", 1) shouldBe expected(("_tmp_0.a", AlwaysEdge)) - succOf("_tmp_0.a") shouldBe expected(("a = _tmp_0.a", AlwaysEdge)) - - succOf("a = _tmp_0.a") shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("b", 1, AlwaysEdge)) - succOf("b", 1) shouldBe expected(("_tmp_0.b", AlwaysEdge)) - succOf("_tmp_0.b") shouldBe expected(("b = _tmp_0.b", AlwaysEdge)) - succOf("b = _tmp_0.b") shouldBe expected(("_tmp_0", 3, AlwaysEdge)) - succOf("_tmp_0", 3) shouldBe expected(("var {a, b} = x", AlwaysEdge)) - succOf("var {a, b} = x") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0 = x", AlwaysEdge)) + + succOf("_tmp_0 = x") should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("a", 1, AlwaysEdge)) + succOf("a", 1) should contain theSameElementsAs expected(("_tmp_0.a", AlwaysEdge)) + succOf("_tmp_0.a") should contain theSameElementsAs expected(("a = _tmp_0.a", AlwaysEdge)) + + succOf("a = _tmp_0.a") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("b", 1, AlwaysEdge)) + succOf("b", 1) should contain theSameElementsAs expected(("_tmp_0.b", AlwaysEdge)) + succOf("_tmp_0.b") should contain theSameElementsAs expected(("b = _tmp_0.b", AlwaysEdge)) + succOf("b = _tmp_0.b") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + succOf("_tmp_0", 3) should contain theSameElementsAs expected(("var {a, b} = x", AlwaysEdge)) + succOf("var {a, b} = x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for object destruction assignment with declaration and ternary init" in { implicit val cpg: Cpg = code("const { a, b } = test() ? foo() : bar()") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("test", AlwaysEdge)) - succOf("test") shouldBe expected(("this", AlwaysEdge)) - succOf("this", NodeTypes.IDENTIFIER) shouldBe expected(("test()", AlwaysEdge)) - succOf("test()") shouldBe expected(("foo", TrueEdge), ("bar", FalseEdge)) - succOf("foo") shouldBe expected(("this", 1, AlwaysEdge)) - succOf("this", 2) shouldBe expected(("foo()", AlwaysEdge)) - succOf("bar()") shouldBe expected(("test() ? foo() : bar()", AlwaysEdge)) - succOf("foo()") shouldBe expected(("test() ? foo() : bar()", AlwaysEdge)) - succOf("test() ? foo() : bar()") shouldBe expected(("_tmp_0 = test() ? foo() : bar()", AlwaysEdge)) - succOf("_tmp_0 = test() ? foo() : bar()") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("a", 1, AlwaysEdge)) - succOf("a", 1) shouldBe expected(("_tmp_0.a", AlwaysEdge)) - succOf("_tmp_0.a") shouldBe expected(("a = _tmp_0.a", AlwaysEdge)) - succOf("a = _tmp_0.a") shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("b", 1, AlwaysEdge)) - succOf("b", 1) shouldBe expected(("_tmp_0.b", AlwaysEdge)) - succOf("_tmp_0.b") shouldBe expected(("b = _tmp_0.b", AlwaysEdge)) - succOf("b = _tmp_0.b") shouldBe expected(("_tmp_0", 3, AlwaysEdge)) - succOf("_tmp_0", 3) shouldBe expected(("const { a, b } = test() ? foo() : bar()", AlwaysEdge)) - succOf("const { a, b } = test() ? foo() : bar()") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("test", AlwaysEdge)) + succOf("test") should contain theSameElementsAs expected(("this", AlwaysEdge)) + succOf("this", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("test()", AlwaysEdge)) + succOf("test()") should contain theSameElementsAs expected(("foo", TrueEdge), ("bar", FalseEdge)) + succOf("foo") should contain theSameElementsAs expected(("this", 1, AlwaysEdge)) + succOf("this", 2) should contain theSameElementsAs expected(("foo()", AlwaysEdge)) + succOf("bar()") should contain theSameElementsAs expected(("test() ? foo() : bar()", AlwaysEdge)) + succOf("foo()") should contain theSameElementsAs expected(("test() ? foo() : bar()", AlwaysEdge)) + succOf("test() ? foo() : bar()") should contain theSameElementsAs expected( + ("_tmp_0 = test() ? foo() : bar()", AlwaysEdge) + ) + succOf("_tmp_0 = test() ? foo() : bar()") should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("a", 1, AlwaysEdge)) + succOf("a", 1) should contain theSameElementsAs expected(("_tmp_0.a", AlwaysEdge)) + succOf("_tmp_0.a") should contain theSameElementsAs expected(("a = _tmp_0.a", AlwaysEdge)) + succOf("a = _tmp_0.a") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("b", 1, AlwaysEdge)) + succOf("b", 1) should contain theSameElementsAs expected(("_tmp_0.b", AlwaysEdge)) + succOf("_tmp_0.b") should contain theSameElementsAs expected(("b = _tmp_0.b", AlwaysEdge)) + succOf("b = _tmp_0.b") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + succOf("_tmp_0", 3) should contain theSameElementsAs expected( + ("const { a, b } = test() ? foo() : bar()", AlwaysEdge) + ) + succOf("const { a, b } = test() ? foo() : bar()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for object destruction assignment with reassignment" in { implicit val cpg: Cpg = code("var {a: n, b: m} = x") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0 = x", AlwaysEdge)) - - succOf("_tmp_0 = x") shouldBe expected(("n", AlwaysEdge)) - succOf("n") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("_tmp_0.a", AlwaysEdge)) - succOf("_tmp_0.a") shouldBe expected(("n = _tmp_0.a", AlwaysEdge)) - - succOf("n = _tmp_0.a") shouldBe expected(("m", AlwaysEdge)) - succOf("m") shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("_tmp_0.b", AlwaysEdge)) - succOf("_tmp_0.b") shouldBe expected(("m = _tmp_0.b", AlwaysEdge)) - succOf("m = _tmp_0.b") shouldBe expected(("_tmp_0", 3, AlwaysEdge)) - succOf("_tmp_0", 3) shouldBe expected(("var {a: n, b: m} = x", AlwaysEdge)) - succOf("var {a: n, b: m} = x") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0 = x", AlwaysEdge)) + + succOf("_tmp_0 = x") should contain theSameElementsAs expected(("n", AlwaysEdge)) + succOf("n") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("_tmp_0.a", AlwaysEdge)) + succOf("_tmp_0.a") should contain theSameElementsAs expected(("n = _tmp_0.a", AlwaysEdge)) + + succOf("n = _tmp_0.a") should contain theSameElementsAs expected(("m", AlwaysEdge)) + succOf("m") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("_tmp_0.b", AlwaysEdge)) + succOf("_tmp_0.b") should contain theSameElementsAs expected(("m = _tmp_0.b", AlwaysEdge)) + succOf("m = _tmp_0.b") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + succOf("_tmp_0", 3) should contain theSameElementsAs expected(("var {a: n, b: m} = x", AlwaysEdge)) + succOf("var {a: n, b: m} = x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for object destruction assignment with reassignment and defaults" in { implicit val cpg: Cpg = code("var {a: n = 1, b: m = 2} = x") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0 = x", AlwaysEdge)) - succOf("_tmp_0 = x") shouldBe expected(("n", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0 = x", AlwaysEdge)) + succOf("_tmp_0 = x") should contain theSameElementsAs expected(("n", AlwaysEdge)) // test statement - succOf("n") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("_tmp_0.a", AlwaysEdge)) - succOf("_tmp_0.a") shouldBe expected(("void 0", AlwaysEdge)) - succOf("void 0") shouldBe expected(("_tmp_0.a === void 0", AlwaysEdge)) + succOf("n") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("_tmp_0.a", AlwaysEdge)) + succOf("_tmp_0.a") should contain theSameElementsAs expected(("void 0", AlwaysEdge)) + succOf("void 0") should contain theSameElementsAs expected(("_tmp_0.a === void 0", AlwaysEdge)) // true, false cases - succOf("_tmp_0.a === void 0") shouldBe expected(("1", TrueEdge), ("_tmp_0", 2, FalseEdge)) - succOf("_tmp_0", 2) shouldBe expected(("a", 1, AlwaysEdge)) - succOf("a", 1) shouldBe expected(("_tmp_0.a", 1, AlwaysEdge)) - succOf("_tmp_0.a", 1) shouldBe expected(("_tmp_0.a === void 0 ? 1 : _tmp_0.a", AlwaysEdge)) - succOf("1") shouldBe expected(("_tmp_0.a === void 0 ? 1 : _tmp_0.a", AlwaysEdge)) - succOf("_tmp_0.a === void 0 ? 1 : _tmp_0.a") shouldBe + succOf("_tmp_0.a === void 0") should contain theSameElementsAs expected(("1", TrueEdge), ("_tmp_0", 2, FalseEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("a", 1, AlwaysEdge)) + succOf("a", 1) should contain theSameElementsAs expected(("_tmp_0.a", 1, AlwaysEdge)) + succOf("_tmp_0.a", 1) should contain theSameElementsAs expected( + ("_tmp_0.a === void 0 ? 1 : _tmp_0.a", AlwaysEdge) + ) + succOf("1") should contain theSameElementsAs expected(("_tmp_0.a === void 0 ? 1 : _tmp_0.a", AlwaysEdge)) + succOf("_tmp_0.a === void 0 ? 1 : _tmp_0.a") should contain theSameElementsAs expected(("n = _tmp_0.a === void 0 ? 1 : _tmp_0.a", AlwaysEdge)) - succOf("n = _tmp_0.a === void 0 ? 1 : _tmp_0.a") shouldBe + succOf("n = _tmp_0.a === void 0 ? 1 : _tmp_0.a") should contain theSameElementsAs expected(("m", AlwaysEdge)) // test statement - succOf("m") shouldBe expected(("_tmp_0", 3, AlwaysEdge)) - succOf("_tmp_0", 3) shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("_tmp_0.b", AlwaysEdge)) - succOf("_tmp_0.b") shouldBe expected(("void 0", 1, AlwaysEdge)) - succOf("void 0", 1) shouldBe expected(("_tmp_0.b === void 0", AlwaysEdge)) + succOf("m") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + succOf("_tmp_0", 3) should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("_tmp_0.b", AlwaysEdge)) + succOf("_tmp_0.b") should contain theSameElementsAs expected(("void 0", 1, AlwaysEdge)) + succOf("void 0", 1) should contain theSameElementsAs expected(("_tmp_0.b === void 0", AlwaysEdge)) // true, false cases - succOf("_tmp_0.b === void 0") shouldBe expected(("2", TrueEdge), ("_tmp_0", 4, FalseEdge)) - succOf("_tmp_0", 4) shouldBe expected(("b", 1, AlwaysEdge)) - succOf("b", 1) shouldBe expected(("_tmp_0.b", 1, AlwaysEdge)) - succOf("_tmp_0.b", 1) shouldBe expected(("_tmp_0.b === void 0 ? 2 : _tmp_0.b", AlwaysEdge)) - succOf("2") shouldBe expected(("_tmp_0.b === void 0 ? 2 : _tmp_0.b", AlwaysEdge)) - succOf("_tmp_0.b === void 0 ? 2 : _tmp_0.b") shouldBe + succOf("_tmp_0.b === void 0") should contain theSameElementsAs expected(("2", TrueEdge), ("_tmp_0", 4, FalseEdge)) + succOf("_tmp_0", 4) should contain theSameElementsAs expected(("b", 1, AlwaysEdge)) + succOf("b", 1) should contain theSameElementsAs expected(("_tmp_0.b", 1, AlwaysEdge)) + succOf("_tmp_0.b", 1) should contain theSameElementsAs expected( + ("_tmp_0.b === void 0 ? 2 : _tmp_0.b", AlwaysEdge) + ) + succOf("2") should contain theSameElementsAs expected(("_tmp_0.b === void 0 ? 2 : _tmp_0.b", AlwaysEdge)) + succOf("_tmp_0.b === void 0 ? 2 : _tmp_0.b") should contain theSameElementsAs expected(("m = _tmp_0.b === void 0 ? 2 : _tmp_0.b", AlwaysEdge)) - succOf("m = _tmp_0.b === void 0 ? 2 : _tmp_0.b") shouldBe + succOf("m = _tmp_0.b === void 0 ? 2 : _tmp_0.b") should contain theSameElementsAs expected(("_tmp_0", 5, AlwaysEdge)) - succOf("_tmp_0", 5) shouldBe expected(("var {a: n = 1, b: m = 2} = x", AlwaysEdge)) - succOf("var {a: n = 1, b: m = 2} = x") shouldBe expected(("RET", AlwaysEdge)) + succOf("_tmp_0", 5) should contain theSameElementsAs expected(("var {a: n = 1, b: m = 2} = x", AlwaysEdge)) + succOf("var {a: n = 1, b: m = 2} = x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for object destruction assignment with rest" in { implicit val cpg: Cpg = code("var {a, ...rest} = x") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0 = x", AlwaysEdge)) - - succOf("_tmp_0 = x") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("a", 1, AlwaysEdge)) - succOf("a", 1) shouldBe expected(("_tmp_0.a", AlwaysEdge)) - succOf("_tmp_0.a") shouldBe expected(("a = _tmp_0.a", AlwaysEdge)) - - succOf("a = _tmp_0.a") shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("rest", AlwaysEdge)) - succOf("rest") shouldBe expected(("...rest", AlwaysEdge)) - succOf("...rest") shouldBe expected(("_tmp_0", 3, AlwaysEdge)) - - succOf("_tmp_0", 3) shouldBe expected(("var {a, ...rest} = x", AlwaysEdge)) - succOf("var {a, ...rest} = x") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0 = x", AlwaysEdge)) + + succOf("_tmp_0 = x") should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("a", 1, AlwaysEdge)) + succOf("a", 1) should contain theSameElementsAs expected(("_tmp_0.a", AlwaysEdge)) + succOf("_tmp_0.a") should contain theSameElementsAs expected(("a = _tmp_0.a", AlwaysEdge)) + + succOf("a = _tmp_0.a") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("rest", AlwaysEdge)) + succOf("rest") should contain theSameElementsAs expected(("...rest", AlwaysEdge)) + succOf("...rest") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + + succOf("_tmp_0", 3) should contain theSameElementsAs expected(("var {a, ...rest} = x", AlwaysEdge)) + succOf("var {a, ...rest} = x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for object destruction assignment with computed property name" in { implicit val cpg: Cpg = code("var {[propName]: n} = x") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0 = x", AlwaysEdge)) - - succOf("_tmp_0 = x") shouldBe expected(("n", AlwaysEdge)) - succOf("n") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("propName", AlwaysEdge)) - succOf("propName") shouldBe expected(("_tmp_0.propName", AlwaysEdge)) - succOf("_tmp_0.propName") shouldBe expected(("n = _tmp_0.propName", AlwaysEdge)) - - succOf("n = _tmp_0.propName") shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("var {[propName]: n} = x", AlwaysEdge)) - succOf("var {[propName]: n} = x") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0 = x", AlwaysEdge)) + + succOf("_tmp_0 = x") should contain theSameElementsAs expected(("n", AlwaysEdge)) + succOf("n") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("propName", AlwaysEdge)) + succOf("propName") should contain theSameElementsAs expected(("_tmp_0.propName", AlwaysEdge)) + succOf("_tmp_0.propName") should contain theSameElementsAs expected(("n = _tmp_0.propName", AlwaysEdge)) + + succOf("n = _tmp_0.propName") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("var {[propName]: n} = x", AlwaysEdge)) + succOf("var {[propName]: n} = x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for nested object destruction assignment with defaults as parameter" in { @@ -172,46 +180,50 @@ class MixedCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCpg |function userId({id = {}, b} = {}) { | return id |}""".stripMargin) - succOf("userId", NodeTypes.METHOD) shouldBe expected(("_tmp_1", AlwaysEdge)) - succOf("_tmp_1") shouldBe expected(("param1_0", AlwaysEdge)) - succOf("param1_0") shouldBe expected(("void 0", AlwaysEdge)) - succOf("void 0") shouldBe expected(("param1_0 === void 0", AlwaysEdge)) - succOf("param1_0 === void 0") shouldBe expected( + succOf("userId", NodeTypes.METHOD) should contain theSameElementsAs expected(("_tmp_1", AlwaysEdge)) + succOf("_tmp_1") should contain theSameElementsAs expected(("param1_0", AlwaysEdge)) + succOf("param1_0") should contain theSameElementsAs expected(("void 0", AlwaysEdge)) + succOf("void 0") should contain theSameElementsAs expected(("param1_0 === void 0", AlwaysEdge)) + succOf("param1_0 === void 0") should contain theSameElementsAs expected( ("_tmp_0", TrueEdge), // holds {} ("param1_0", 1, FalseEdge) ) - succOf("param1_0", 1) shouldBe expected(("param1_0 === void 0 ? {} : param1_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("param1_0 === void 0 ? {} : param1_0", AlwaysEdge)) - succOf("param1_0 === void 0 ? {} : param1_0") shouldBe expected( + succOf("param1_0", 1) should contain theSameElementsAs expected( + ("param1_0 === void 0 ? {} : param1_0", AlwaysEdge) + ) + succOf("_tmp_0") should contain theSameElementsAs expected(("param1_0 === void 0 ? {} : param1_0", AlwaysEdge)) + succOf("param1_0 === void 0 ? {} : param1_0") should contain theSameElementsAs expected( ("_tmp_1 = param1_0 === void 0 ? {} : param1_0", AlwaysEdge) ) - succOf("_tmp_1 = param1_0 === void 0 ? {} : param1_0") shouldBe expected(("id", AlwaysEdge)) - succOf("id") shouldBe expected(("_tmp_1", 1, AlwaysEdge)) - succOf("_tmp_1", 1) shouldBe expected(("id", 1, AlwaysEdge)) - succOf("id", 1) shouldBe expected(("_tmp_1.id", AlwaysEdge)) - succOf("_tmp_1.id") shouldBe expected(("void 0", 1, AlwaysEdge)) - succOf("void 0", 1) shouldBe expected(("_tmp_1.id === void 0", AlwaysEdge)) - succOf("_tmp_1.id === void 0") shouldBe expected( + succOf("_tmp_1 = param1_0 === void 0 ? {} : param1_0") should contain theSameElementsAs expected( + ("id", AlwaysEdge) + ) + succOf("id") should contain theSameElementsAs expected(("_tmp_1", 1, AlwaysEdge)) + succOf("_tmp_1", 1) should contain theSameElementsAs expected(("id", 1, AlwaysEdge)) + succOf("id", 1) should contain theSameElementsAs expected(("_tmp_1.id", AlwaysEdge)) + succOf("_tmp_1.id") should contain theSameElementsAs expected(("void 0", 1, AlwaysEdge)) + succOf("void 0", 1) should contain theSameElementsAs expected(("_tmp_1.id === void 0", AlwaysEdge)) + succOf("_tmp_1.id === void 0") should contain theSameElementsAs expected( ("_tmp_2", TrueEdge), // holds {} ("_tmp_1", 2, FalseEdge) ) - succOf("_tmp_2") shouldBe expected(("_tmp_1.id === void 0 ? {} : _tmp_1.id", AlwaysEdge)) - succOf("_tmp_1", 2) shouldBe expected(("id", 2, AlwaysEdge)) + succOf("_tmp_2") should contain theSameElementsAs expected(("_tmp_1.id === void 0 ? {} : _tmp_1.id", AlwaysEdge)) + succOf("_tmp_1", 2) should contain theSameElementsAs expected(("id", 2, AlwaysEdge)) - succOf("_tmp_1.id === void 0 ? {} : _tmp_1.id") shouldBe expected( + succOf("_tmp_1.id === void 0 ? {} : _tmp_1.id") should contain theSameElementsAs expected( ("id = _tmp_1.id === void 0 ? {} : _tmp_1.id", AlwaysEdge) ) - succOf("id", 2) shouldBe expected(("_tmp_1.id", 1, AlwaysEdge)) + succOf("id", 2) should contain theSameElementsAs expected(("_tmp_1.id", 1, AlwaysEdge)) - succOf("id = _tmp_1.id === void 0 ? {} : _tmp_1.id") shouldBe expected(("b", AlwaysEdge)) + succOf("id = _tmp_1.id === void 0 ? {} : _tmp_1.id") should contain theSameElementsAs expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("_tmp_1", 3, AlwaysEdge)) - succOf("_tmp_1", 3) shouldBe expected(("b", 1, AlwaysEdge)) - succOf("b", 1) shouldBe expected(("_tmp_1.b", AlwaysEdge)) - succOf("_tmp_1.b") shouldBe expected(("b = _tmp_1.b", AlwaysEdge)) - succOf("b = _tmp_1.b") shouldBe expected(("_tmp_1", 4, AlwaysEdge)) - succOf("_tmp_1", 4) shouldBe expected(("{id = {}, b} = {}", 1, AlwaysEdge)) - succOf("{id = {}, b} = {}", 1) shouldBe expected(("id", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("_tmp_1", 3, AlwaysEdge)) + succOf("_tmp_1", 3) should contain theSameElementsAs expected(("b", 1, AlwaysEdge)) + succOf("b", 1) should contain theSameElementsAs expected(("_tmp_1.b", AlwaysEdge)) + succOf("_tmp_1.b") should contain theSameElementsAs expected(("b = _tmp_1.b", AlwaysEdge)) + succOf("b = _tmp_1.b") should contain theSameElementsAs expected(("_tmp_1", 4, AlwaysEdge)) + succOf("_tmp_1", 4) should contain theSameElementsAs expected(("{id = {}, b} = {}", 1, AlwaysEdge)) + succOf("{id = {}, b} = {}", 1) should contain theSameElementsAs expected(("id", AlwaysEdge)) } @@ -220,151 +232,163 @@ class MixedCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCpg |function userId({id}) { | return id |}""".stripMargin) - succOf("userId", NodeTypes.METHOD) shouldBe expected(("id", AlwaysEdge)) - succOf("id") shouldBe expected(("param1_0", AlwaysEdge)) - succOf("param1_0") shouldBe expected(("id", 1, AlwaysEdge)) - succOf("id", 1) shouldBe expected(("param1_0.id", AlwaysEdge)) - succOf("param1_0.id") shouldBe expected(("id = param1_0.id", AlwaysEdge)) - succOf("id = param1_0.id") shouldBe expected(("id", 2, AlwaysEdge)) - succOf("id", 2) shouldBe expected(("return id", AlwaysEdge)) - succOf("return id") shouldBe expected(("RET", AlwaysEdge)) + succOf("userId", NodeTypes.METHOD) should contain theSameElementsAs expected(("id", AlwaysEdge)) + succOf("id") should contain theSameElementsAs expected(("param1_0", AlwaysEdge)) + succOf("param1_0") should contain theSameElementsAs expected(("id", 1, AlwaysEdge)) + succOf("id", 1) should contain theSameElementsAs expected(("param1_0.id", AlwaysEdge)) + succOf("param1_0.id") should contain theSameElementsAs expected(("id = param1_0.id", AlwaysEdge)) + succOf("id = param1_0.id") should contain theSameElementsAs expected(("id", 2, AlwaysEdge)) + succOf("id", 2) should contain theSameElementsAs expected(("return id", AlwaysEdge)) + succOf("return id") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for array destruction assignment with declaration" in { implicit val cpg: Cpg = code("var [a, b] = x") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0 = x", AlwaysEdge)) - - succOf("_tmp_0 = x") shouldBe expected(("a", AlwaysEdge)) - - succOf("a") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("0", AlwaysEdge)) - succOf("0") shouldBe expected(("_tmp_0[0]", AlwaysEdge)) - succOf("_tmp_0[0]") shouldBe expected(("a = _tmp_0[0]", AlwaysEdge)) - - succOf("a = _tmp_0[0]") shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("_tmp_0[1]", AlwaysEdge)) - succOf("_tmp_0[1]") shouldBe expected(("b = _tmp_0[1]", AlwaysEdge)) - succOf("b = _tmp_0[1]") shouldBe expected(("_tmp_0", 3, AlwaysEdge)) - succOf("_tmp_0", 3) shouldBe expected(("var [a, b] = x", AlwaysEdge)) - succOf("var [a, b] = x") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0 = x", AlwaysEdge)) + + succOf("_tmp_0 = x") should contain theSameElementsAs expected(("a", AlwaysEdge)) + + succOf("a") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("0") should contain theSameElementsAs expected(("_tmp_0[0]", AlwaysEdge)) + succOf("_tmp_0[0]") should contain theSameElementsAs expected(("a = _tmp_0[0]", AlwaysEdge)) + + succOf("a = _tmp_0[0]") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("_tmp_0[1]", AlwaysEdge)) + succOf("_tmp_0[1]") should contain theSameElementsAs expected(("b = _tmp_0[1]", AlwaysEdge)) + succOf("b = _tmp_0[1]") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + succOf("_tmp_0", 3) should contain theSameElementsAs expected(("var [a, b] = x", AlwaysEdge)) + succOf("var [a, b] = x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for array destruction assignment without declaration" in { implicit val cpg: Cpg = code("[a, b] = x") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0 = x", AlwaysEdge)) - - succOf("_tmp_0 = x") shouldBe expected(("a", AlwaysEdge)) - - succOf("a") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("0", AlwaysEdge)) - succOf("0") shouldBe expected(("_tmp_0[0]", AlwaysEdge)) - succOf("_tmp_0[0]") shouldBe expected(("a = _tmp_0[0]", AlwaysEdge)) - - succOf("a = _tmp_0[0]") shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("_tmp_0[1]", AlwaysEdge)) - succOf("_tmp_0[1]") shouldBe expected(("b = _tmp_0[1]", AlwaysEdge)) - succOf("b = _tmp_0[1]") shouldBe expected(("_tmp_0", 3, AlwaysEdge)) - succOf("_tmp_0", 3) shouldBe expected(("[a, b] = x", AlwaysEdge)) - succOf("[a, b] = x") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0 = x", AlwaysEdge)) + + succOf("_tmp_0 = x") should contain theSameElementsAs expected(("a", AlwaysEdge)) + + succOf("a") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("0") should contain theSameElementsAs expected(("_tmp_0[0]", AlwaysEdge)) + succOf("_tmp_0[0]") should contain theSameElementsAs expected(("a = _tmp_0[0]", AlwaysEdge)) + + succOf("a = _tmp_0[0]") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("_tmp_0[1]", AlwaysEdge)) + succOf("_tmp_0[1]") should contain theSameElementsAs expected(("b = _tmp_0[1]", AlwaysEdge)) + succOf("b = _tmp_0[1]") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + succOf("_tmp_0", 3) should contain theSameElementsAs expected(("[a, b] = x", AlwaysEdge)) + succOf("[a, b] = x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for array destruction assignment with defaults" in { implicit val cpg: Cpg = code("var [a = 1, b = 2] = x") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0 = x", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0 = x", AlwaysEdge)) - succOf("_tmp_0 = x") shouldBe expected(("a", AlwaysEdge)) + succOf("_tmp_0 = x") should contain theSameElementsAs expected(("a", AlwaysEdge)) // test statement - succOf("a") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("0", AlwaysEdge)) - succOf("0") shouldBe expected(("_tmp_0[0]", AlwaysEdge)) - succOf("_tmp_0[0]") shouldBe expected(("void 0", AlwaysEdge)) - succOf("void 0") shouldBe expected(("_tmp_0[0] === void 0", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("0") should contain theSameElementsAs expected(("_tmp_0[0]", AlwaysEdge)) + succOf("_tmp_0[0]") should contain theSameElementsAs expected(("void 0", AlwaysEdge)) + succOf("void 0") should contain theSameElementsAs expected(("_tmp_0[0] === void 0", AlwaysEdge)) // true, false cases - succOf("_tmp_0[0] === void 0") shouldBe expected(("1", TrueEdge), ("_tmp_0", 2, FalseEdge)) - succOf("_tmp_0", 2) shouldBe expected(("0", 1, AlwaysEdge)) - succOf("0", 1) shouldBe expected(("_tmp_0[0]", 1, AlwaysEdge)) - succOf("_tmp_0[0]", 1) shouldBe expected(("_tmp_0[0] === void 0 ? 1 : _tmp_0[0]", AlwaysEdge)) - succOf("_tmp_0[0] === void 0 ? 1 : _tmp_0[0]") shouldBe expected( + succOf("_tmp_0[0] === void 0") should contain theSameElementsAs expected( + ("1", TrueEdge), + ("_tmp_0", 2, FalseEdge) + ) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("0", 1, AlwaysEdge)) + succOf("0", 1) should contain theSameElementsAs expected(("_tmp_0[0]", 1, AlwaysEdge)) + succOf("_tmp_0[0]", 1) should contain theSameElementsAs expected( + ("_tmp_0[0] === void 0 ? 1 : _tmp_0[0]", AlwaysEdge) + ) + succOf("_tmp_0[0] === void 0 ? 1 : _tmp_0[0]") should contain theSameElementsAs expected( ("a = _tmp_0[0] === void 0 ? 1 : _tmp_0[0]", AlwaysEdge) ) - succOf("a = _tmp_0[0] === void 0 ? 1 : _tmp_0[0]") shouldBe expected(("b", AlwaysEdge)) + succOf("a = _tmp_0[0] === void 0 ? 1 : _tmp_0[0]") should contain theSameElementsAs expected(("b", AlwaysEdge)) // test statement - succOf("b") shouldBe expected(("_tmp_0", 3, AlwaysEdge)) - succOf("_tmp_0", 3) shouldBe expected(("1", 1, AlwaysEdge)) - succOf("1", 1) shouldBe expected(("_tmp_0[1]", AlwaysEdge)) - succOf("_tmp_0[1]") shouldBe expected(("void 0", 1, AlwaysEdge)) - succOf("void 0", 1) shouldBe expected(("_tmp_0[1] === void 0", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + succOf("_tmp_0", 3) should contain theSameElementsAs expected(("1", 1, AlwaysEdge)) + succOf("1", 1) should contain theSameElementsAs expected(("_tmp_0[1]", AlwaysEdge)) + succOf("_tmp_0[1]") should contain theSameElementsAs expected(("void 0", 1, AlwaysEdge)) + succOf("void 0", 1) should contain theSameElementsAs expected(("_tmp_0[1] === void 0", AlwaysEdge)) // true, false cases - succOf("_tmp_0[1] === void 0") shouldBe expected(("2", TrueEdge), ("_tmp_0", 4, FalseEdge)) - succOf("_tmp_0", 4) shouldBe expected(("1", 2, AlwaysEdge)) - succOf("1", 2) shouldBe expected(("_tmp_0[1]", 1, AlwaysEdge)) - succOf("_tmp_0[1]", 1) shouldBe expected(("_tmp_0[1] === void 0 ? 2 : _tmp_0[1]", AlwaysEdge)) - succOf("_tmp_0[1] === void 0 ? 2 : _tmp_0[1]") shouldBe expected( + succOf("_tmp_0[1] === void 0") should contain theSameElementsAs expected( + ("2", TrueEdge), + ("_tmp_0", 4, FalseEdge) + ) + succOf("_tmp_0", 4) should contain theSameElementsAs expected(("1", 2, AlwaysEdge)) + succOf("1", 2) should contain theSameElementsAs expected(("_tmp_0[1]", 1, AlwaysEdge)) + succOf("_tmp_0[1]", 1) should contain theSameElementsAs expected( + ("_tmp_0[1] === void 0 ? 2 : _tmp_0[1]", AlwaysEdge) + ) + succOf("_tmp_0[1] === void 0 ? 2 : _tmp_0[1]") should contain theSameElementsAs expected( ("b = _tmp_0[1] === void 0 ? 2 : _tmp_0[1]", AlwaysEdge) ) - succOf("b = _tmp_0[1] === void 0 ? 2 : _tmp_0[1]") shouldBe expected(("_tmp_0", 5, AlwaysEdge)) - succOf("_tmp_0", 5) shouldBe expected(("var [a = 1, b = 2] = x", AlwaysEdge)) - succOf("var [a = 1, b = 2] = x") shouldBe expected(("RET", AlwaysEdge)) + succOf("b = _tmp_0[1] === void 0 ? 2 : _tmp_0[1]") should contain theSameElementsAs expected( + ("_tmp_0", 5, AlwaysEdge) + ) + succOf("_tmp_0", 5) should contain theSameElementsAs expected(("var [a = 1, b = 2] = x", AlwaysEdge)) + succOf("var [a = 1, b = 2] = x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for array destruction assignment with ignores" in { implicit val cpg: Cpg = code("var [a, , b] = x") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0 = x", AlwaysEdge)) - - succOf("_tmp_0 = x") shouldBe expected(("a", AlwaysEdge)) - - succOf("a") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("0", AlwaysEdge)) - succOf("0") shouldBe expected(("_tmp_0[0]", AlwaysEdge)) - succOf("_tmp_0[0]") shouldBe expected(("a = _tmp_0[0]", AlwaysEdge)) - - succOf("a = _tmp_0[0]") shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("_tmp_0[2]", AlwaysEdge)) - succOf("_tmp_0[2]") shouldBe expected(("b = _tmp_0[2]", AlwaysEdge)) - succOf("b = _tmp_0[2]") shouldBe expected(("_tmp_0", 3, AlwaysEdge)) - succOf("_tmp_0", 3) shouldBe expected(("var [a, , b] = x", AlwaysEdge)) - succOf("var [a, , b] = x") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0 = x", AlwaysEdge)) + + succOf("_tmp_0 = x") should contain theSameElementsAs expected(("a", AlwaysEdge)) + + succOf("a") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("0") should contain theSameElementsAs expected(("_tmp_0[0]", AlwaysEdge)) + succOf("_tmp_0[0]") should contain theSameElementsAs expected(("a = _tmp_0[0]", AlwaysEdge)) + + succOf("a = _tmp_0[0]") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("_tmp_0[2]", AlwaysEdge)) + succOf("_tmp_0[2]") should contain theSameElementsAs expected(("b = _tmp_0[2]", AlwaysEdge)) + succOf("b = _tmp_0[2]") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + succOf("_tmp_0", 3) should contain theSameElementsAs expected(("var [a, , b] = x", AlwaysEdge)) + succOf("var [a, , b] = x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for array destruction assignment with rest" in { implicit val cpg: Cpg = code("var [a, ...rest] = x") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0 = x", AlwaysEdge)) - - succOf("_tmp_0 = x") shouldBe expected(("a", AlwaysEdge)) - - succOf("a") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("0", AlwaysEdge)) - succOf("0") shouldBe expected(("_tmp_0[0]", AlwaysEdge)) - succOf("_tmp_0[0]") shouldBe expected(("a = _tmp_0[0]", AlwaysEdge)) - - succOf("a = _tmp_0[0]") shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("_tmp_0[1]", AlwaysEdge)) - succOf("_tmp_0[1]") shouldBe expected(("rest", AlwaysEdge)) - succOf("rest") shouldBe expected(("...rest", AlwaysEdge)) - succOf("...rest") shouldBe expected(("_tmp_0", 3, AlwaysEdge)) - succOf("_tmp_0", 3) shouldBe expected(("var [a, ...rest] = x", AlwaysEdge)) - succOf("var [a, ...rest] = x") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0 = x", AlwaysEdge)) + + succOf("_tmp_0 = x") should contain theSameElementsAs expected(("a", AlwaysEdge)) + + succOf("a") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("0") should contain theSameElementsAs expected(("_tmp_0[0]", AlwaysEdge)) + succOf("_tmp_0[0]") should contain theSameElementsAs expected(("a = _tmp_0[0]", AlwaysEdge)) + + succOf("a = _tmp_0[0]") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("_tmp_0[1]", AlwaysEdge)) + succOf("_tmp_0[1]") should contain theSameElementsAs expected(("rest", AlwaysEdge)) + succOf("rest") should contain theSameElementsAs expected(("...rest", AlwaysEdge)) + succOf("...rest") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + succOf("_tmp_0", 3) should contain theSameElementsAs expected(("var [a, ...rest] = x", AlwaysEdge)) + succOf("var [a, ...rest] = x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for array destruction assignment as parameter" in { @@ -373,25 +397,25 @@ class MixedCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCpg | return id |} |""".stripMargin) - succOf("userId", NodeTypes.METHOD) shouldBe expected(("id", AlwaysEdge)) - succOf("id") shouldBe expected(("param1_0", AlwaysEdge)) - succOf("param1_0") shouldBe expected(("id", 1, AlwaysEdge)) - succOf("id", 1) shouldBe expected(("param1_0.id", AlwaysEdge)) - succOf("param1_0.id") shouldBe expected(("id = param1_0.id", AlwaysEdge)) - succOf("id = param1_0.id") shouldBe expected(("id", 2, AlwaysEdge)) - succOf("id", 2) shouldBe expected(("return id", AlwaysEdge)) - succOf("return id") shouldBe expected(("RET", AlwaysEdge)) + succOf("userId", NodeTypes.METHOD) should contain theSameElementsAs expected(("id", AlwaysEdge)) + succOf("id") should contain theSameElementsAs expected(("param1_0", AlwaysEdge)) + succOf("param1_0") should contain theSameElementsAs expected(("id", 1, AlwaysEdge)) + succOf("id", 1) should contain theSameElementsAs expected(("param1_0.id", AlwaysEdge)) + succOf("param1_0.id") should contain theSameElementsAs expected(("id = param1_0.id", AlwaysEdge)) + succOf("id = param1_0.id") should contain theSameElementsAs expected(("id", 2, AlwaysEdge)) + succOf("id", 2) should contain theSameElementsAs expected(("return id", AlwaysEdge)) + succOf("return id") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "CFG generation for spread arguments" should { "have correct structure for method spread argument" in { implicit val cpg: Cpg = code("foo(...args)") - succOf(":program") shouldBe expected(("foo", AlwaysEdge)) - succOf("foo") shouldBe expected(("this", AlwaysEdge)) - succOf("this", NodeTypes.IDENTIFIER) shouldBe expected(("args", AlwaysEdge)) - succOf("args") shouldBe expected(("...args", AlwaysEdge)) - succOf("...args") shouldBe expected(("foo(...args)", AlwaysEdge)) - succOf("foo(...args)") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("foo", AlwaysEdge)) + succOf("foo") should contain theSameElementsAs expected(("this", AlwaysEdge)) + succOf("this", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("args", AlwaysEdge)) + succOf("args") should contain theSameElementsAs expected(("...args", AlwaysEdge)) + succOf("...args") should contain theSameElementsAs expected(("foo(...args)", AlwaysEdge)) + succOf("foo(...args)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } @@ -400,110 +424,110 @@ class MixedCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCpg "CFG generation for await/async" should { "be correct for await/async" in { implicit val cpg: Cpg = code("async function x(foo) { await foo() }") - succOf("x", NodeTypes.METHOD) shouldBe expected(("foo", AlwaysEdge)) - succOf("foo", NodeTypes.IDENTIFIER) shouldBe expected(("this", AlwaysEdge)) - succOf("this", NodeTypes.IDENTIFIER) shouldBe expected(("foo()", AlwaysEdge)) - succOf("foo()") shouldBe expected(("await foo()", AlwaysEdge)) - succOf("await foo()") shouldBe expected(("RET", AlwaysEdge)) + succOf("x", NodeTypes.METHOD) should contain theSameElementsAs expected(("foo", AlwaysEdge)) + succOf("foo", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("this", AlwaysEdge)) + succOf("this", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("foo()", AlwaysEdge)) + succOf("foo()") should contain theSameElementsAs expected(("await foo()", AlwaysEdge)) + succOf("await foo()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } "CFG generation for instanceof/delete" should { "be correct for instanceof" in { implicit val cpg: Cpg = code("x instanceof Foo") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("Foo", AlwaysEdge)) - succOf("Foo") shouldBe expected(("x instanceof Foo", AlwaysEdge)) - succOf("x instanceof Foo", NodeTypes.CALL) shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("Foo", AlwaysEdge)) + succOf("Foo") should contain theSameElementsAs expected(("x instanceof Foo", AlwaysEdge)) + succOf("x instanceof Foo", NodeTypes.CALL) should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for delete" in { implicit val cpg: Cpg = code("delete foo.x") - succOf(":program") shouldBe expected(("foo", AlwaysEdge)) - succOf("foo") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("foo.x", AlwaysEdge)) - succOf("foo.x") shouldBe expected(("delete foo.x", AlwaysEdge)) - succOf("delete foo.x", NodeTypes.CALL) shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("foo", AlwaysEdge)) + succOf("foo") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("foo.x", AlwaysEdge)) + succOf("foo.x") should contain theSameElementsAs expected(("delete foo.x", AlwaysEdge)) + succOf("delete foo.x", NodeTypes.CALL) should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } "CFG generation for default parameters" should { "be correct for method parameter with default" in { implicit val cpg: Cpg = code("function foo(a = 1) { }") - cpg.method.nameExact("foo").parameter.code.l shouldBe List("this", "a = 1") - - succOf("foo", NodeTypes.METHOD) shouldBe expected(("a", AlwaysEdge)) - succOf("a", NodeTypes.IDENTIFIER) shouldBe expected(("a", 1, AlwaysEdge)) - succOf("a", 1) shouldBe expected(("void 0", AlwaysEdge)) - succOf("void 0") shouldBe expected(("a === void 0", AlwaysEdge)) - succOf("a === void 0") shouldBe expected(("1", TrueEdge), ("a", 2, FalseEdge)) - succOf("1") shouldBe expected(("a === void 0 ? 1 : a", AlwaysEdge)) - succOf("a", 2) shouldBe expected(("a === void 0 ? 1 : a", AlwaysEdge)) - succOf("a === void 0 ? 1 : a") shouldBe expected(("a = a === void 0 ? 1 : a", AlwaysEdge)) - succOf("a = a === void 0 ? 1 : a") shouldBe expected(("RET", AlwaysEdge)) + cpg.method.nameExact("foo").parameter.code.l should contain theSameElementsAs List("this", "a = 1") + + succOf("foo", NodeTypes.METHOD) should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("a", 1, AlwaysEdge)) + succOf("a", 1) should contain theSameElementsAs expected(("void 0", AlwaysEdge)) + succOf("void 0") should contain theSameElementsAs expected(("a === void 0", AlwaysEdge)) + succOf("a === void 0") should contain theSameElementsAs expected(("1", TrueEdge), ("a", 2, FalseEdge)) + succOf("1") should contain theSameElementsAs expected(("a === void 0 ? 1 : a", AlwaysEdge)) + succOf("a", 2) should contain theSameElementsAs expected(("a === void 0 ? 1 : a", AlwaysEdge)) + succOf("a === void 0 ? 1 : a") should contain theSameElementsAs expected(("a = a === void 0 ? 1 : a", AlwaysEdge)) + succOf("a = a === void 0 ? 1 : a") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for multiple method parameters with default" in { implicit val cpg: Cpg = code("function foo(a = 1, b = 2) { }") - cpg.method.nameExact("foo").parameter.code.l shouldBe List("this", "a = 1", "b = 2") - - succOf("foo", NodeTypes.METHOD) shouldBe expected(("a", AlwaysEdge)) - succOf("a", NodeTypes.IDENTIFIER) shouldBe expected(("a", 1, AlwaysEdge)) - succOf("a", 1) shouldBe expected(("void 0", AlwaysEdge)) - succOf("void 0") shouldBe expected(("a === void 0", AlwaysEdge)) - succOf("a === void 0") shouldBe expected(("1", TrueEdge), ("a", 2, FalseEdge)) - succOf("1") shouldBe expected(("a === void 0 ? 1 : a", AlwaysEdge)) - succOf("a", 2) shouldBe expected(("a === void 0 ? 1 : a", AlwaysEdge)) - succOf("a === void 0 ? 1 : a") shouldBe expected(("a = a === void 0 ? 1 : a", AlwaysEdge)) - succOf("a = a === void 0 ? 1 : a") shouldBe expected(("b", AlwaysEdge)) - - succOf("b", NodeTypes.IDENTIFIER) shouldBe expected(("b", 1, AlwaysEdge)) - succOf("b", 1) shouldBe expected(("void 0", 1, AlwaysEdge)) - succOf("void 0", 1) shouldBe expected(("b === void 0", AlwaysEdge)) - succOf("b === void 0") shouldBe expected(("2", TrueEdge), ("b", 2, FalseEdge)) - succOf("2") shouldBe expected(("b === void 0 ? 2 : b", AlwaysEdge)) - succOf("b", 2) shouldBe expected(("b === void 0 ? 2 : b", AlwaysEdge)) - succOf("b === void 0 ? 2 : b") shouldBe expected(("b = b === void 0 ? 2 : b", AlwaysEdge)) - succOf("b = b === void 0 ? 2 : b") shouldBe expected(("RET", AlwaysEdge)) + cpg.method.nameExact("foo").parameter.code.l should contain theSameElementsAs List("this", "a = 1", "b = 2") + + succOf("foo", NodeTypes.METHOD) should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("a", 1, AlwaysEdge)) + succOf("a", 1) should contain theSameElementsAs expected(("void 0", AlwaysEdge)) + succOf("void 0") should contain theSameElementsAs expected(("a === void 0", AlwaysEdge)) + succOf("a === void 0") should contain theSameElementsAs expected(("1", TrueEdge), ("a", 2, FalseEdge)) + succOf("1") should contain theSameElementsAs expected(("a === void 0 ? 1 : a", AlwaysEdge)) + succOf("a", 2) should contain theSameElementsAs expected(("a === void 0 ? 1 : a", AlwaysEdge)) + succOf("a === void 0 ? 1 : a") should contain theSameElementsAs expected(("a = a === void 0 ? 1 : a", AlwaysEdge)) + succOf("a = a === void 0 ? 1 : a") should contain theSameElementsAs expected(("b", AlwaysEdge)) + + succOf("b", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("b", 1, AlwaysEdge)) + succOf("b", 1) should contain theSameElementsAs expected(("void 0", 1, AlwaysEdge)) + succOf("void 0", 1) should contain theSameElementsAs expected(("b === void 0", AlwaysEdge)) + succOf("b === void 0") should contain theSameElementsAs expected(("2", TrueEdge), ("b", 2, FalseEdge)) + succOf("2") should contain theSameElementsAs expected(("b === void 0 ? 2 : b", AlwaysEdge)) + succOf("b", 2) should contain theSameElementsAs expected(("b === void 0 ? 2 : b", AlwaysEdge)) + succOf("b === void 0 ? 2 : b") should contain theSameElementsAs expected(("b = b === void 0 ? 2 : b", AlwaysEdge)) + succOf("b = b === void 0 ? 2 : b") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for method mixed parameters with default" in { implicit val cpg: Cpg = code("function foo(a, b = 1) { }") - cpg.method.nameExact("foo").parameter.code.l shouldBe List("this", "a", "b = 1") - - succOf("foo", NodeTypes.METHOD) shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("b", 1, AlwaysEdge)) - succOf("b", 1) shouldBe expected(("void 0", AlwaysEdge)) - succOf("void 0") shouldBe expected(("b === void 0", AlwaysEdge)) - succOf("b === void 0") shouldBe expected(("1", TrueEdge), ("b", 2, FalseEdge)) - succOf("1") shouldBe expected(("b === void 0 ? 1 : b", AlwaysEdge)) - succOf("b", 2) shouldBe expected(("b === void 0 ? 1 : b", AlwaysEdge)) - succOf("b === void 0 ? 1 : b") shouldBe expected(("b = b === void 0 ? 1 : b", AlwaysEdge)) - succOf("b = b === void 0 ? 1 : b") shouldBe expected(("RET", AlwaysEdge)) + cpg.method.nameExact("foo").parameter.code.l should contain theSameElementsAs List("this", "a", "b = 1") + + succOf("foo", NodeTypes.METHOD) should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("b", 1, AlwaysEdge)) + succOf("b", 1) should contain theSameElementsAs expected(("void 0", AlwaysEdge)) + succOf("void 0") should contain theSameElementsAs expected(("b === void 0", AlwaysEdge)) + succOf("b === void 0") should contain theSameElementsAs expected(("1", TrueEdge), ("b", 2, FalseEdge)) + succOf("1") should contain theSameElementsAs expected(("b === void 0 ? 1 : b", AlwaysEdge)) + succOf("b", 2) should contain theSameElementsAs expected(("b === void 0 ? 1 : b", AlwaysEdge)) + succOf("b === void 0 ? 1 : b") should contain theSameElementsAs expected(("b = b === void 0 ? 1 : b", AlwaysEdge)) + succOf("b = b === void 0 ? 1 : b") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for multiple method mixed parameters with default" in { implicit val cpg: Cpg = code("function foo(x, a = 1, b = 2) { }") - cpg.method.nameExact("foo").parameter.code.l shouldBe List("this", "x", "a = 1", "b = 2") - - succOf("foo", NodeTypes.METHOD) shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("a", 1, AlwaysEdge)) - succOf("a", 1) shouldBe expected(("void 0", AlwaysEdge)) - succOf("void 0") shouldBe expected(("a === void 0", AlwaysEdge)) - succOf("a === void 0") shouldBe expected(("1", TrueEdge), ("a", 2, FalseEdge)) - succOf("1") shouldBe expected(("a === void 0 ? 1 : a", AlwaysEdge)) - succOf("a", 2) shouldBe expected(("a === void 0 ? 1 : a", AlwaysEdge)) - succOf("a === void 0 ? 1 : a") shouldBe expected(("a = a === void 0 ? 1 : a", AlwaysEdge)) - succOf("a = a === void 0 ? 1 : a") shouldBe expected(("b", AlwaysEdge)) - - succOf("b") shouldBe expected(("b", 1, AlwaysEdge)) - succOf("b", 1) shouldBe expected(("void 0", 1, AlwaysEdge)) - succOf("void 0", 1) shouldBe expected(("b === void 0", AlwaysEdge)) - succOf("b === void 0") shouldBe expected(("2", TrueEdge), ("b", 2, FalseEdge)) - succOf("2") shouldBe expected(("b === void 0 ? 2 : b", AlwaysEdge)) - succOf("b", 2) shouldBe expected(("b === void 0 ? 2 : b", AlwaysEdge)) - succOf("b === void 0 ? 2 : b") shouldBe expected(("b = b === void 0 ? 2 : b", AlwaysEdge)) - succOf("b = b === void 0 ? 2 : b") shouldBe expected(("RET", AlwaysEdge)) + cpg.method.nameExact("foo").parameter.code.l should contain theSameElementsAs List("this", "x", "a = 1", "b = 2") + + succOf("foo", NodeTypes.METHOD) should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("a", 1, AlwaysEdge)) + succOf("a", 1) should contain theSameElementsAs expected(("void 0", AlwaysEdge)) + succOf("void 0") should contain theSameElementsAs expected(("a === void 0", AlwaysEdge)) + succOf("a === void 0") should contain theSameElementsAs expected(("1", TrueEdge), ("a", 2, FalseEdge)) + succOf("1") should contain theSameElementsAs expected(("a === void 0 ? 1 : a", AlwaysEdge)) + succOf("a", 2) should contain theSameElementsAs expected(("a === void 0 ? 1 : a", AlwaysEdge)) + succOf("a === void 0 ? 1 : a") should contain theSameElementsAs expected(("a = a === void 0 ? 1 : a", AlwaysEdge)) + succOf("a = a === void 0 ? 1 : a") should contain theSameElementsAs expected(("b", AlwaysEdge)) + + succOf("b") should contain theSameElementsAs expected(("b", 1, AlwaysEdge)) + succOf("b", 1) should contain theSameElementsAs expected(("void 0", 1, AlwaysEdge)) + succOf("void 0", 1) should contain theSameElementsAs expected(("b === void 0", AlwaysEdge)) + succOf("b === void 0") should contain theSameElementsAs expected(("2", TrueEdge), ("b", 2, FalseEdge)) + succOf("2") should contain theSameElementsAs expected(("b === void 0 ? 2 : b", AlwaysEdge)) + succOf("b", 2) should contain theSameElementsAs expected(("b === void 0 ? 2 : b", AlwaysEdge)) + succOf("b === void 0 ? 2 : b") should contain theSameElementsAs expected(("b = b === void 0 ? 2 : b", AlwaysEdge)) + succOf("b = b === void 0 ? 2 : b") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/SimpleCfgCreationPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/SimpleCfgCreationPassTests.scala index 9a290cbb5931..04aaa2121da7 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/SimpleCfgCreationPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/cfg/SimpleCfgCreationPassTests.scala @@ -11,85 +11,113 @@ class SimpleCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCp "CFG generation for simple fragments" should { "have correct structure for block expression" in { implicit val cpg: Cpg = code("let x = (class Foo {}, bar())") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("class Foo", AlwaysEdge)) - succOf("class Foo") shouldBe expected(("bar", AlwaysEdge)) - succOf("bar") shouldBe expected(("this", AlwaysEdge)) - succOf("this", NodeTypes.IDENTIFIER) shouldBe expected(("bar()", AlwaysEdge)) - succOf("bar()") shouldBe expected(("class Foo {}, bar()", AlwaysEdge)) - succOf("class Foo {}, bar()") shouldBe expected(("let x = (class Foo {}, bar())", AlwaysEdge)) - succOf("let x = (class Foo {}, bar())") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("class Foo", AlwaysEdge)) + succOf("class Foo") should contain theSameElementsAs expected(("bar", AlwaysEdge)) + succOf("bar") should contain theSameElementsAs expected(("this", AlwaysEdge)) + succOf("this", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("bar()", AlwaysEdge)) + succOf("bar()") should contain theSameElementsAs expected(("class Foo {}, bar()", AlwaysEdge)) + succOf("class Foo {}, bar()") should contain theSameElementsAs expected( + ("let x = (class Foo {}, bar())", AlwaysEdge) + ) + succOf("let x = (class Foo {}, bar())") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "have correct structure for empty array literal" in { implicit val cpg: Cpg = code("var x = []") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("__ecma.Array.factory()", AlwaysEdge)) - succOf("__ecma.Array.factory()") shouldBe expected(("var x = []", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("__ecma.Array.factory()", AlwaysEdge)) + succOf("__ecma.Array.factory()") should contain theSameElementsAs expected(("var x = []", AlwaysEdge)) } "have correct structure for array literal with values" in { implicit val cpg: Cpg = code("var x = [1, 2]") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("__ecma.Array.factory()", AlwaysEdge)) - succOf("__ecma.Array.factory()") shouldBe expected(("_tmp_0 = __ecma.Array.factory()", AlwaysEdge)) - - succOf("_tmp_0 = __ecma.Array.factory()") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("push", AlwaysEdge)) - succOf("push") shouldBe expected(("_tmp_0.push", AlwaysEdge)) - succOf("_tmp_0.push") shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("_tmp_0.push(1)", AlwaysEdge)) - - succOf("_tmp_0.push(1)") shouldBe expected(("_tmp_0", 3, AlwaysEdge)) - succOf("_tmp_0", 3) shouldBe expected(("push", 1, AlwaysEdge)) - succOf("push", 1) shouldBe expected(("_tmp_0.push", 1, AlwaysEdge)) - succOf("_tmp_0.push", 1) shouldBe expected(("_tmp_0", 4, AlwaysEdge)) - succOf("_tmp_0", 4) shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("_tmp_0.push(2)", AlwaysEdge)) - - succOf("_tmp_0.push(2)") shouldBe expected(("_tmp_0", 5, AlwaysEdge)) - succOf("_tmp_0", 5) shouldBe expected(("[1, 2]", AlwaysEdge)) - succOf("[1, 2]") shouldBe expected(("var x = [1, 2]", AlwaysEdge)) - succOf("var x = [1, 2]") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("__ecma.Array.factory()", AlwaysEdge)) + succOf("__ecma.Array.factory()") should contain theSameElementsAs expected( + ("_tmp_0 = __ecma.Array.factory()", AlwaysEdge) + ) + + succOf("_tmp_0 = __ecma.Array.factory()") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("push", AlwaysEdge)) + succOf("push") should contain theSameElementsAs expected(("_tmp_0.push", AlwaysEdge)) + succOf("_tmp_0.push") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("_tmp_0.push(1)", AlwaysEdge)) + + succOf("_tmp_0.push(1)") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + succOf("_tmp_0", 3) should contain theSameElementsAs expected(("push", 1, AlwaysEdge)) + succOf("push", 1) should contain theSameElementsAs expected(("_tmp_0.push", 1, AlwaysEdge)) + succOf("_tmp_0.push", 1) should contain theSameElementsAs expected(("_tmp_0", 4, AlwaysEdge)) + succOf("_tmp_0", 4) should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("_tmp_0.push(2)", AlwaysEdge)) + + succOf("_tmp_0.push(2)") should contain theSameElementsAs expected(("_tmp_0", 5, AlwaysEdge)) + succOf("_tmp_0", 5) should contain theSameElementsAs expected(("[1, 2]", AlwaysEdge)) + succOf("[1, 2]") should contain theSameElementsAs expected(("var x = [1, 2]", AlwaysEdge)) + succOf("var x = [1, 2]") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "have correct structure for untagged runtime node in call" in { implicit val cpg: Cpg = code(s"foo(`Hello $${world}!`)") - succOf(":program") shouldBe expected(("foo", AlwaysEdge)) - succOf("foo") shouldBe expected(("this", AlwaysEdge)) - succOf("this", NodeTypes.IDENTIFIER) shouldBe expected(("\"Hello \"", AlwaysEdge)) - succOf("\"Hello \"") shouldBe expected(("world", AlwaysEdge)) - succOf("world") shouldBe expected(("\"!\"", AlwaysEdge)) - succOf("\"!\"") shouldBe expected((s"${Operators.formatString}(\"Hello \", world, \"!\")", AlwaysEdge)) - succOf(s"${Operators.formatString}(\"Hello \", world, \"!\")") shouldBe expected( + succOf(":program") should contain theSameElementsAs expected(("foo", AlwaysEdge)) + succOf("foo") should contain theSameElementsAs expected(("this", AlwaysEdge)) + succOf("this", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("\"Hello \"", AlwaysEdge)) + succOf("\"Hello \"") should contain theSameElementsAs expected(("world", AlwaysEdge)) + succOf("world") should contain theSameElementsAs expected(("\"!\"", AlwaysEdge)) + succOf("\"!\"") should contain theSameElementsAs expected( + (s"${Operators.formatString}(\"Hello \", world, \"!\")", AlwaysEdge) + ) + succOf(s"${Operators.formatString}(\"Hello \", world, \"!\")") should contain theSameElementsAs expected( (s"foo(`Hello $${world}!`)", AlwaysEdge) ) - succOf(s"foo(`Hello $${world}!`)") shouldBe expected(("RET", AlwaysEdge)) + succOf(s"foo(`Hello $${world}!`)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "have correct structure for untagged runtime node" in { implicit val cpg: Cpg = code(s"`$${x + 1}`") - succOf(":program") shouldBe expected(("\"\"", AlwaysEdge)) - succOf("\"\"") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x + 1", AlwaysEdge)) - succOf("x + 1") shouldBe expected(("\"\"", 1, AlwaysEdge)) - succOf("\"\"", 1) shouldBe expected((s"${Operators.formatString}(\"\", x + 1, \"\")", AlwaysEdge)) - succOf(s"${Operators.formatString}(\"\", x + 1, \"\")") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("\"\"", AlwaysEdge)) + succOf("\"\"") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x + 1", AlwaysEdge)) + succOf("x + 1") should contain theSameElementsAs expected(("\"\"", 1, AlwaysEdge)) + succOf("\"\"", 1) should contain theSameElementsAs expected( + (s"${Operators.formatString}(\"\", x + 1, \"\")", AlwaysEdge) + ) + succOf(s"${Operators.formatString}(\"\", x + 1, \"\")") should contain theSameElementsAs expected( + ("RET", AlwaysEdge) + ) } "have correct structure for tagged runtime node" in { implicit val cpg: Cpg = code(s"String.raw`../$${42}\\..`") - succOf(":program") shouldBe expected(("\"../\"", AlwaysEdge)) - succOf("\"../\"") shouldBe expected(("42", AlwaysEdge)) - succOf("42") shouldBe expected(("\"\\..\"", AlwaysEdge)) - succOf("\"\\..\"") shouldBe expected((s"${Operators.formatString}(\"../\", 42, \"\\..\")", AlwaysEdge)) - succOf(s"${Operators.formatString}(\"../\", 42, \"\\..\")") shouldBe expected( - (s"String.raw(${Operators.formatString}(\"../\", 42, \"\\..\"))", AlwaysEdge) + succOf(":program") should contain theSameElementsAs expected(("String", AlwaysEdge)) + succOf("String") should contain theSameElementsAs expected(("raw", AlwaysEdge)) + succOf("raw") should contain theSameElementsAs expected(("String.raw", AlwaysEdge)) + succOf("String.raw") should contain theSameElementsAs expected(("String", 1, AlwaysEdge)) + succOf("String", 1) should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("__ecma.Array.factory()", AlwaysEdge)) + succOf("__ecma.Array.factory()") should contain theSameElementsAs expected( + ("_tmp_0 = __ecma.Array.factory()", AlwaysEdge) ) - succOf(s"String.raw(${Operators.formatString}(\"../\", 42, \"\\..\"))") shouldBe expected(("RET", AlwaysEdge)) + succOf("_tmp_0 = __ecma.Array.factory()") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("push", AlwaysEdge)) + succOf("push") should contain theSameElementsAs expected(("_tmp_0.push", AlwaysEdge)) + succOf("_tmp_0.push") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("\"../\"", AlwaysEdge)) + succOf("\"../\"") should contain theSameElementsAs expected(("_tmp_0.push(\"../\")", AlwaysEdge)) + succOf("_tmp_0.push(\"../\")") should contain theSameElementsAs expected(("_tmp_0", 3, AlwaysEdge)) + succOf("_tmp_0", 3) should contain theSameElementsAs expected(("push", 1, AlwaysEdge)) + succOf("push", 1) should contain theSameElementsAs expected(("_tmp_0.push", 1, AlwaysEdge)) + succOf("_tmp_0.push", 1) should contain theSameElementsAs expected(("_tmp_0", 4, AlwaysEdge)) + succOf("_tmp_0", 4) should contain theSameElementsAs expected(("\"\\..\"", AlwaysEdge)) + succOf("\"\\..\"") should contain theSameElementsAs expected(("_tmp_0.push(\"\\..\")", AlwaysEdge)) + succOf("_tmp_0.push(\"\\..\")") should contain theSameElementsAs expected(("_tmp_0", 5, AlwaysEdge)) + succOf("_tmp_0", 5) should contain theSameElementsAs expected(("`../${42}\\..`", AlwaysEdge)) + succOf("`../${42}\\..`") should contain theSameElementsAs expected(("42", AlwaysEdge)) + succOf("42") should contain theSameElementsAs expected((s"String.raw`../$${42}\\..`", AlwaysEdge)) + succOf(s"String.raw`../$${42}\\..`") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for try" in { @@ -102,13 +130,13 @@ class SimpleCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCp | close() |} |""".stripMargin) - succOf(":program") shouldBe expected(("open", AlwaysEdge)) - succOf("open") shouldBe expected(("this", AlwaysEdge)) - succOf("this", NodeTypes.IDENTIFIER) shouldBe expected(("open()", AlwaysEdge)) - succOf("open()") shouldBe expected(("err", AlwaysEdge), ("close", AlwaysEdge)) - succOf("err") shouldBe expected(("handle", AlwaysEdge)) - succOf("handle()") shouldBe expected(("close", AlwaysEdge)) - succOf("close()") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("open", AlwaysEdge)) + succOf("open") should contain theSameElementsAs expected(("this", AlwaysEdge)) + succOf("this", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("open()", AlwaysEdge)) + succOf("open()") should contain theSameElementsAs expected(("err", AlwaysEdge), ("close", AlwaysEdge)) + succOf("err") should contain theSameElementsAs expected(("handle", AlwaysEdge)) + succOf("handle()") should contain theSameElementsAs expected(("close", AlwaysEdge)) + succOf("close()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for try with multiple CFG exit nodes in try block" in { @@ -125,14 +153,14 @@ class SimpleCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCp | close() |} |""".stripMargin) - succOf(":program") shouldBe expected(("true", AlwaysEdge)) - succOf("true") shouldBe expected(("doA", TrueEdge), ("doB", FalseEdge)) - succOf("doA()") shouldBe expected(("err", AlwaysEdge), ("close", AlwaysEdge)) - succOf("err") shouldBe expected(("handle", AlwaysEdge)) - succOf("doB()") shouldBe expected(("err", AlwaysEdge), ("close", AlwaysEdge)) - succOf("err") shouldBe expected(("handle", AlwaysEdge)) - succOf("handle()") shouldBe expected(("close", AlwaysEdge)) - succOf("close()") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("true", AlwaysEdge)) + succOf("true") should contain theSameElementsAs expected(("doA", TrueEdge), ("doB", FalseEdge)) + succOf("doA()") should contain theSameElementsAs expected(("err", AlwaysEdge), ("close", AlwaysEdge)) + succOf("err") should contain theSameElementsAs expected(("handle", AlwaysEdge)) + succOf("doB()") should contain theSameElementsAs expected(("err", AlwaysEdge), ("close", AlwaysEdge)) + succOf("err") should contain theSameElementsAs expected(("handle", AlwaysEdge)) + succOf("handle()") should contain theSameElementsAs expected(("close", AlwaysEdge)) + succOf("close()") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for 1 object with simple values" in { @@ -142,133 +170,135 @@ class SimpleCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCp | key2: 2 |} |""".stripMargin) - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("key1", AlwaysEdge)) - succOf("key1") shouldBe expected(("_tmp_0.key1", AlwaysEdge)) - succOf("_tmp_0.key1") shouldBe expected(("\"value\"", AlwaysEdge)) - succOf("\"value\"") shouldBe expected(("_tmp_0.key1 = \"value\"", AlwaysEdge)) - - succOf("_tmp_0.key1 = \"value\"") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("key2", AlwaysEdge)) - succOf("key2") shouldBe expected(("_tmp_0.key2", AlwaysEdge)) - succOf("_tmp_0.key2") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("_tmp_0.key2 = 2", AlwaysEdge)) - - succOf("_tmp_0.key2 = 2") shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("{\n key1: \"value\",\n key2: 2\n}", AlwaysEdge)) - succOf("{\n key1: \"value\",\n key2: 2\n}") shouldBe expected( + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("key1", AlwaysEdge)) + succOf("key1") should contain theSameElementsAs expected(("_tmp_0.key1", AlwaysEdge)) + succOf("_tmp_0.key1") should contain theSameElementsAs expected(("\"value\"", AlwaysEdge)) + succOf("\"value\"") should contain theSameElementsAs expected(("_tmp_0.key1 = \"value\"", AlwaysEdge)) + + succOf("_tmp_0.key1 = \"value\"") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("key2", AlwaysEdge)) + succOf("key2") should contain theSameElementsAs expected(("_tmp_0.key2", AlwaysEdge)) + succOf("_tmp_0.key2") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("_tmp_0.key2 = 2", AlwaysEdge)) + + succOf("_tmp_0.key2 = 2") should contain theSameElementsAs expected(("_tmp_0", 2, AlwaysEdge)) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("{\n key1: \"value\",\n key2: 2\n}", AlwaysEdge)) + succOf("{\n key1: \"value\",\n key2: 2\n}") should contain theSameElementsAs expected( ("var x = {\n key1: \"value\",\n key2: 2\n}", AlwaysEdge) ) - succOf("var x = {\n key1: \"value\",\n key2: 2\n}") shouldBe expected(("RET", AlwaysEdge)) + succOf("var x = {\n key1: \"value\",\n key2: 2\n}") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for member access used in an assignment (chained)" in { implicit val cpg: Cpg = code("a.b = c.z;") - succOf(":program") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("a.b", AlwaysEdge)) - succOf("a.b") shouldBe expected(("c", AlwaysEdge)) - succOf("c") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("c.z", AlwaysEdge)) - succOf("c.z") shouldBe expected(("a.b = c.z", AlwaysEdge)) - succOf("a.b = c.z") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("a.b", AlwaysEdge)) + succOf("a.b") should contain theSameElementsAs expected(("c", AlwaysEdge)) + succOf("c") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("c.z", AlwaysEdge)) + succOf("c.z") should contain theSameElementsAs expected(("a.b = c.z", AlwaysEdge)) + succOf("a.b = c.z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for decl statement with assignment" in { implicit val cpg: Cpg = code("var x = 1;") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("var x = 1", AlwaysEdge)) - succOf("var x = 1") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("var x = 1", AlwaysEdge)) + succOf("var x = 1") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for nested expression" in { implicit val cpg: Cpg = code("x = y + 1;") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y + 1", AlwaysEdge)) - succOf("y + 1") shouldBe expected(("x = y + 1", AlwaysEdge)) - succOf("x = y + 1") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y + 1", AlwaysEdge)) + succOf("y + 1") should contain theSameElementsAs expected(("x = y + 1", AlwaysEdge)) + succOf("x = y + 1") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for return statement" in { implicit val cpg: Cpg = code("function foo(x) { return x; }") - succOf("foo", NodeTypes.METHOD) shouldBe expected(("x", AlwaysEdge)) - succOf("x", NodeTypes.IDENTIFIER) shouldBe expected(("return x", AlwaysEdge)) - succOf("return x") shouldBe expected(("RET", AlwaysEdge)) + succOf("foo", NodeTypes.METHOD) should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("return x", AlwaysEdge)) + succOf("return x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for consecutive return statements" in { implicit val cpg: Cpg = code("function foo(x, y) { return x; return y; }") - succOf("foo", NodeTypes.METHOD) shouldBe expected(("x", AlwaysEdge)) - succOf("x", NodeTypes.IDENTIFIER) shouldBe expected(("return x", AlwaysEdge)) - succOf("y", NodeTypes.IDENTIFIER) shouldBe expected(("return y", AlwaysEdge)) - succOf("return x") shouldBe expected(("RET", AlwaysEdge)) - succOf("return y") shouldBe expected(("RET", AlwaysEdge)) + succOf("foo", NodeTypes.METHOD) should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("return x", AlwaysEdge)) + succOf("y", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("return y", AlwaysEdge)) + succOf("return x") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("return y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for outer program function which declares foo function object" in { implicit val cpg: Cpg = code("function foo(x, y) { return; }") - succOf(":program", NodeTypes.METHOD) shouldBe expected(("foo", 2, AlwaysEdge)) - succOf("foo", NodeTypes.IDENTIFIER) shouldBe expected(("foo", 3, AlwaysEdge)) - succOf("foo", NodeTypes.METHOD_REF) shouldBe expected( + succOf(":program", NodeTypes.METHOD) should contain theSameElementsAs expected(("foo", 2, AlwaysEdge)) + succOf("foo", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("foo", 3, AlwaysEdge)) + succOf("foo", NodeTypes.METHOD_REF) should contain theSameElementsAs expected( ("function foo = function foo(x, y) { return; }", AlwaysEdge) ) - succOf("function foo = function foo(x, y) { return; }") shouldBe expected(("RET", AlwaysEdge)) + succOf("function foo = function foo(x, y) { return; }") should contain theSameElementsAs expected( + ("RET", AlwaysEdge) + ) } "be correct for void return statement" in { implicit val cpg: Cpg = code("function foo() { return; }") - succOf("foo", NodeTypes.METHOD) shouldBe expected(("return", AlwaysEdge)) - succOf("return") shouldBe expected(("RET", AlwaysEdge)) + succOf("foo", NodeTypes.METHOD) should contain theSameElementsAs expected(("return", AlwaysEdge)) + succOf("return") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for call expression" in { implicit val cpg: Cpg = code("foo(a + 1, b);") - succOf(":program") shouldBe expected(("foo", AlwaysEdge)) - succOf("foo") shouldBe expected(("this", AlwaysEdge)) - succOf("this", NodeTypes.IDENTIFIER) shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("a + 1", AlwaysEdge)) - succOf("a + 1") shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("foo(a + 1, b)", AlwaysEdge)) - succOf("foo(a + 1, b)") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("foo", AlwaysEdge)) + succOf("foo") should contain theSameElementsAs expected(("this", AlwaysEdge)) + succOf("this", NodeTypes.IDENTIFIER) should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("a + 1", AlwaysEdge)) + succOf("a + 1") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("foo(a + 1, b)", AlwaysEdge)) + succOf("foo(a + 1, b)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for chained calls" in { implicit val cpg: Cpg = code("x.foo(y).bar(z)") - succOf(":program") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("foo", AlwaysEdge)) - succOf("foo") shouldBe expected(("x.foo", AlwaysEdge)) - succOf("x.foo") shouldBe expected(("x", 1, AlwaysEdge)) - succOf("x", 1) shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("x.foo(y)", AlwaysEdge)) - succOf("x.foo(y)") shouldBe expected(("(_tmp_0 = x.foo(y))", AlwaysEdge)) - succOf("(_tmp_0 = x.foo(y))") shouldBe expected(("bar", AlwaysEdge)) - succOf("bar") shouldBe expected(("(_tmp_0 = x.foo(y)).bar", AlwaysEdge)) - succOf("(_tmp_0 = x.foo(y)).bar") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("x.foo(y).bar(z)", AlwaysEdge)) - succOf("x.foo(y).bar(z)") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("foo", AlwaysEdge)) + succOf("foo") should contain theSameElementsAs expected(("x.foo", AlwaysEdge)) + succOf("x.foo") should contain theSameElementsAs expected(("x", 1, AlwaysEdge)) + succOf("x", 1) should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("x.foo(y)", AlwaysEdge)) + succOf("x.foo(y)") should contain theSameElementsAs expected(("(_tmp_0 = x.foo(y))", AlwaysEdge)) + succOf("(_tmp_0 = x.foo(y))") should contain theSameElementsAs expected(("bar", AlwaysEdge)) + succOf("bar") should contain theSameElementsAs expected(("(_tmp_0 = x.foo(y)).bar", AlwaysEdge)) + succOf("(_tmp_0 = x.foo(y)).bar") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("x.foo(y).bar(z)", AlwaysEdge)) + succOf("x.foo(y).bar(z)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for unary expression '++'" in { implicit val cpg: Cpg = code("x++") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("x++", AlwaysEdge)) - succOf("x++") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("x++", AlwaysEdge)) + succOf("x++") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for conditional expression" in { implicit val cpg: Cpg = code("x ? y : z;") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("z", FalseEdge)) - succOf("y") shouldBe expected(("x ? y : z", AlwaysEdge)) - succOf("z") shouldBe expected(("x ? y : z", AlwaysEdge)) - succOf("x ? y : z") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("z", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("x ? y : z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("x ? y : z", AlwaysEdge)) + succOf("x ? y : z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for labeled expressions with continue" in { @@ -283,98 +313,101 @@ class SimpleCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCp | } |} |""".stripMargin) - succOf(":program") shouldBe expected(("var i, j;", AlwaysEdge)) - succOf("loop1:") shouldBe expected(("i", AlwaysEdge)) - succOf("i") shouldBe expected(("0", AlwaysEdge)) - succOf("0") shouldBe expected(("i = 0", AlwaysEdge)) - succOf("i = 0") shouldBe expected(("i", 1, AlwaysEdge)) - succOf("i", 1) shouldBe expected(("3", AlwaysEdge)) - succOf("3") shouldBe expected(("i < 3", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("var i, j;", AlwaysEdge)) + succOf("loop1:") should contain theSameElementsAs expected(("i", AlwaysEdge)) + succOf("i") should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("0") should contain theSameElementsAs expected(("i = 0", AlwaysEdge)) + succOf("i = 0") should contain theSameElementsAs expected(("i", 1, AlwaysEdge)) + succOf("i", 1) should contain theSameElementsAs expected(("3", AlwaysEdge)) + succOf("3") should contain theSameElementsAs expected(("i < 3", AlwaysEdge)) import io.shiftleft.semanticcpg.language._ val codeStr = cpg.method.ast.code(".*loop1:.*").code.head - succOf("i < 3") shouldBe expected(("loop2:", AlwaysEdge), (codeStr, AlwaysEdge)) - succOf(codeStr) shouldBe expected(("RET", AlwaysEdge)) + succOf("i < 3") should contain theSameElementsAs expected(("loop2:", AlwaysEdge), (codeStr, AlwaysEdge)) + succOf(codeStr) should contain theSameElementsAs expected(("RET", AlwaysEdge)) - succOf("loop2:") shouldBe expected(("j", AlwaysEdge)) - succOf("j") shouldBe expected(("0", 1, AlwaysEdge)) - succOf("0", 1) shouldBe expected(("j = 0", AlwaysEdge)) - succOf("j = 0") shouldBe expected(("j", 1, AlwaysEdge)) - succOf("j", 1) shouldBe expected(("3", 1, AlwaysEdge)) - succOf("3", 1) shouldBe expected(("j < 3", AlwaysEdge)) + succOf("loop2:") should contain theSameElementsAs expected(("j", AlwaysEdge)) + succOf("j") should contain theSameElementsAs expected(("0", 1, AlwaysEdge)) + succOf("0", 1) should contain theSameElementsAs expected(("j = 0", AlwaysEdge)) + succOf("j = 0") should contain theSameElementsAs expected(("j", 1, AlwaysEdge)) + succOf("j", 1) should contain theSameElementsAs expected(("3", 1, AlwaysEdge)) + succOf("3", 1) should contain theSameElementsAs expected(("j < 3", AlwaysEdge)) val code2 = cpg.method.ast.isBlock.code("loop2: for.*").code.head - succOf("j < 3") shouldBe expected((code2, AlwaysEdge), ("i", 2, AlwaysEdge)) - succOf(code2) shouldBe expected(("i", 2, AlwaysEdge)) - - succOf("i", 2) shouldBe expected(("i++", AlwaysEdge)) - succOf("i++") shouldBe expected(("i", 3, AlwaysEdge)) - succOf("i", 3) shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("i === 1", AlwaysEdge)) - succOf("i === 1") shouldBe expected(("j", AlwaysEdge), ("i === 1 && j === 1", AlwaysEdge)) - succOf("i === 1 && j === 1") shouldBe expected(("continue loop1;", AlwaysEdge), ("console", AlwaysEdge)) - succOf("continue loop1;") shouldBe expected(("loop1:", AlwaysEdge)) - succOf("console") shouldBe expected(("log", AlwaysEdge)) - succOf("log") shouldBe expected(("console.log", AlwaysEdge)) + succOf("j < 3") should contain theSameElementsAs expected((code2, AlwaysEdge), ("i", 2, AlwaysEdge)) + succOf(code2) should contain theSameElementsAs expected(("i", 2, AlwaysEdge)) + + succOf("i", 2) should contain theSameElementsAs expected(("i++", AlwaysEdge)) + succOf("i++") should contain theSameElementsAs expected(("i", 3, AlwaysEdge)) + succOf("i", 3) should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("i === 1", AlwaysEdge)) + succOf("i === 1") should contain theSameElementsAs expected(("j", AlwaysEdge), ("i === 1 && j === 1", AlwaysEdge)) + succOf("i === 1 && j === 1") should contain theSameElementsAs expected( + ("continue loop1;", AlwaysEdge), + ("console", AlwaysEdge) + ) + succOf("continue loop1;") should contain theSameElementsAs expected(("loop1:", AlwaysEdge)) + succOf("console") should contain theSameElementsAs expected(("log", AlwaysEdge)) + succOf("log") should contain theSameElementsAs expected(("console.log", AlwaysEdge)) } "be correct for plain while loop" in { implicit val cpg: Cpg = code("while (x < 1) { y = 2; }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("y", TrueEdge), ("RET", FalseEdge)) - succOf("y") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("y = 2", AlwaysEdge)) - succOf("y = 2") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("y", TrueEdge), ("RET", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("y = 2", AlwaysEdge)) + succOf("y = 2") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) } "be correct for plain while loop with break" in { implicit val cpg: Cpg = code("while (x < 1) { break; y; }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("break;", TrueEdge), ("RET", FalseEdge)) - succOf("break;") shouldBe expected(("RET", AlwaysEdge)) - succOf("y") shouldBe expected(("x", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("break;", TrueEdge), ("RET", FalseEdge)) + succOf("break;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("x", AlwaysEdge)) } "be correct for plain while loop with continue" in { implicit val cpg: Cpg = code("while (x < 1) { continue; y; }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("continue;", TrueEdge), ("RET", FalseEdge)) - succOf("continue;") shouldBe expected(("x", AlwaysEdge)) - succOf("y") shouldBe expected(("x", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("continue;", TrueEdge), ("RET", FalseEdge)) + succOf("continue;") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("x", AlwaysEdge)) } "be correct for nested while loop" in { implicit val cpg: Cpg = code("while (x) {while(y) {z;}}") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("RET", FalseEdge)) - succOf("y") shouldBe expected(("z", TrueEdge), ("x", FalseEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("RET", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("z", TrueEdge), ("x", FalseEdge)) } "be correct for nested while loop with break" in { implicit val cpg: Cpg = code("while (x) { while(y) { break; z;} a;} b;") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("b", FalseEdge)) - succOf("y") shouldBe expected(("break;", TrueEdge), ("a", FalseEdge)) - succOf("a") shouldBe expected(("x", AlwaysEdge)) - succOf("b") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("b", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("break;", TrueEdge), ("a", FalseEdge)) + succOf("a") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for another nested while loop with break" in { implicit val cpg: Cpg = code("while (x) { while(y) { break; z;} a; break; b; } c;") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("c", FalseEdge)) - succOf("y") shouldBe expected(("break;", TrueEdge), ("a", FalseEdge)) - succOf("break;") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("break;", 1, AlwaysEdge)) - succOf("break;", 1) shouldBe expected(("c", AlwaysEdge)) - succOf("c") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("c", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("break;", TrueEdge), ("a", FalseEdge)) + succOf("break;") should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("break;", 1, AlwaysEdge)) + succOf("break;", 1) should contain theSameElementsAs expected(("c", AlwaysEdge)) + succOf("c") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "nested while loop with conditional break" in { @@ -388,134 +421,134 @@ class SimpleCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCp | } |} """.stripMargin) - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("RET", FalseEdge)) - succOf("y") shouldBe expected(("break;", TrueEdge), ("z", FalseEdge)) - succOf("break;") shouldBe expected(("RET", AlwaysEdge)) - succOf("break;", 1) shouldBe expected(("x", AlwaysEdge)) - succOf("z") shouldBe expected(("break;", 1, TrueEdge), ("x", FalseEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("RET", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("break;", TrueEdge), ("z", FalseEdge)) + succOf("break;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("break;", 1) should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("break;", 1, TrueEdge), ("x", FalseEdge)) } // DO-WHILE Loops "be correct for plain do-while loop" in { implicit val cpg: Cpg = code("do { y = 2; } while (x < 1);") - succOf(":program") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("y = 2", AlwaysEdge)) - succOf("y = 2") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("y", TrueEdge), ("RET", FalseEdge)) + succOf(":program") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("y = 2", AlwaysEdge)) + succOf("y = 2") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("y", TrueEdge), ("RET", FalseEdge)) } "be correct for plain do-while loop with break" in { implicit val cpg: Cpg = code("do { break; y; } while (x < 1);") - succOf(":program") shouldBe expected(("break;", AlwaysEdge)) - succOf("break;") shouldBe expected(("RET", AlwaysEdge)) - succOf("y") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("break;", TrueEdge), ("RET", FalseEdge)) + succOf(":program") should contain theSameElementsAs expected(("break;", AlwaysEdge)) + succOf("break;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("break;", TrueEdge), ("RET", FalseEdge)) } "be correct for plain do-while loop with continue" in { implicit val cpg: Cpg = code("do { continue; y; } while (x < 1);") - succOf(":program") shouldBe expected(("continue;", AlwaysEdge)) - succOf("continue;") shouldBe expected(("x", AlwaysEdge)) - succOf("y") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("continue;", TrueEdge), ("RET", FalseEdge)) + succOf(":program") should contain theSameElementsAs expected(("continue;", AlwaysEdge)) + succOf("continue;") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("continue;", TrueEdge), ("RET", FalseEdge)) } "be correct for nested do-while loop with continue" in { implicit val cpg: Cpg = code("do { do { x; } while (y); } while (z);") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("x", TrueEdge), ("z", FalseEdge)) - succOf("z") shouldBe expected(("x", TrueEdge), ("RET", FalseEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("x", TrueEdge), ("z", FalseEdge)) + succOf("z") should contain theSameElementsAs expected(("x", TrueEdge), ("RET", FalseEdge)) } "be correct for nested while/do-while loops with break" in { implicit val cpg: Cpg = code("while (x) { do { while(y) { break; a; } z; } while (x < 1); } c;") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("c", FalseEdge)) - succOf("y") shouldBe expected(("break;", TrueEdge), ("z", FalseEdge)) - succOf("break;") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("x", 1, AlwaysEdge)) - succOf("x", 1) shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) - succOf("x < 1") shouldBe expected(("y", TrueEdge), ("x", FalseEdge)) - succOf("c") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("c", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("break;", TrueEdge), ("z", FalseEdge)) + succOf("break;") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("x", 1, AlwaysEdge)) + succOf("x", 1) should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("x < 1", AlwaysEdge)) + succOf("x < 1") should contain theSameElementsAs expected(("y", TrueEdge), ("x", FalseEdge)) + succOf("c") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for nested while/do-while loops with break and continue" in { implicit val cpg: Cpg = code("while(x) { do { break; } while (y) } o;") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("break;", TrueEdge), ("o", FalseEdge)) - succOf("break;") shouldBe expected(("x", AlwaysEdge)) - succOf("o") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("break;", TrueEdge), ("o", FalseEdge)) + succOf("break;") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("o") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for two nested while loop with inner break" in { implicit val cpg: Cpg = code("while(y) { while(z) { break; x; } }") - succOf(":program") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("z", TrueEdge), ("RET", FalseEdge)) - succOf("z") shouldBe expected(("break;", TrueEdge), ("y", FalseEdge)) - succOf("break;") shouldBe expected(("y", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("z", TrueEdge), ("RET", FalseEdge)) + succOf("z") should contain theSameElementsAs expected(("break;", TrueEdge), ("y", FalseEdge)) + succOf("break;") should contain theSameElementsAs expected(("y", AlwaysEdge)) } // FOR Loops "be correct for plain for-loop" in { implicit val cpg: Cpg = code("for (x = 0; y < 1; z += 2) { a = 3; }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("0", AlwaysEdge)) - succOf("0") shouldBe expected(("x = 0", AlwaysEdge)) - succOf("x = 0") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y < 1", AlwaysEdge)) - succOf("y < 1") shouldBe expected(("a", TrueEdge), ("RET", FalseEdge)) - succOf("a") shouldBe expected(("3", AlwaysEdge)) - succOf("3") shouldBe expected(("a = 3", AlwaysEdge)) - succOf("a = 3") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("z += 2", AlwaysEdge)) - succOf("z += 2") shouldBe expected(("y", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("0") should contain theSameElementsAs expected(("x = 0", AlwaysEdge)) + succOf("x = 0") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y < 1", AlwaysEdge)) + succOf("y < 1") should contain theSameElementsAs expected(("a", TrueEdge), ("RET", FalseEdge)) + succOf("a") should contain theSameElementsAs expected(("3", AlwaysEdge)) + succOf("3") should contain theSameElementsAs expected(("a = 3", AlwaysEdge)) + succOf("a = 3") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("z += 2", AlwaysEdge)) + succOf("z += 2") should contain theSameElementsAs expected(("y", AlwaysEdge)) } "be correct for plain for-loop with break" in { implicit val cpg: Cpg = code("for (x = 0; y < 1; z += 2) { break; a = 3; }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("0", AlwaysEdge)) - succOf("x = 0") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y < 1", AlwaysEdge)) - succOf("y < 1") shouldBe expected(("break;", TrueEdge), ("RET", FalseEdge)) - succOf("break;") shouldBe expected(("RET", AlwaysEdge)) - succOf("a") shouldBe expected(("3", AlwaysEdge)) - succOf("3") shouldBe expected(("a = 3", AlwaysEdge)) - succOf("a = 3") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("z += 2", AlwaysEdge)) - succOf("z += 2") shouldBe expected(("y", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("x = 0") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y < 1", AlwaysEdge)) + succOf("y < 1") should contain theSameElementsAs expected(("break;", TrueEdge), ("RET", FalseEdge)) + succOf("break;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("3", AlwaysEdge)) + succOf("3") should contain theSameElementsAs expected(("a = 3", AlwaysEdge)) + succOf("a = 3") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("z += 2", AlwaysEdge)) + succOf("z += 2") should contain theSameElementsAs expected(("y", AlwaysEdge)) } "be correct for plain for-loop with continue" in { implicit val cpg: Cpg = code("for (x = 0; y < 1; z += 2) { continue; a = 3; }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("0", AlwaysEdge)) - succOf("0") shouldBe expected(("x = 0", AlwaysEdge)) - succOf("x = 0") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y < 1", AlwaysEdge)) - succOf("y < 1") shouldBe expected(("continue;", TrueEdge), ("RET", FalseEdge)) - succOf("continue;") shouldBe expected(("z", AlwaysEdge)) - succOf("a") shouldBe expected(("3", AlwaysEdge)) - succOf("3") shouldBe expected(("a = 3", AlwaysEdge)) - succOf("a = 3") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("z += 2", AlwaysEdge)) - succOf("z += 2") shouldBe expected(("y", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("0", AlwaysEdge)) + succOf("0") should contain theSameElementsAs expected(("x = 0", AlwaysEdge)) + succOf("x = 0") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y < 1", AlwaysEdge)) + succOf("y < 1") should contain theSameElementsAs expected(("continue;", TrueEdge), ("RET", FalseEdge)) + succOf("continue;") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("3", AlwaysEdge)) + succOf("3") should contain theSameElementsAs expected(("a = 3", AlwaysEdge)) + succOf("a = 3") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("z += 2", AlwaysEdge)) + succOf("z += 2") should contain theSameElementsAs expected(("y", AlwaysEdge)) } "be correct for for-loop with for-in" in { @@ -530,193 +563,214 @@ class SimpleCfgCreationPassTests extends CfgTestFixture(() => new JsSrcCfgTestCp "be correct for nested for-loop" in { implicit val cpg: Cpg = code("for (x; y; z) { for (a; b; c) { u; } }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("a", TrueEdge), ("RET", FalseEdge)) - succOf("z") shouldBe expected(("y", AlwaysEdge)) - succOf("a") shouldBe expected(("b", AlwaysEdge)) - succOf("b") shouldBe expected(("u", TrueEdge), ("z", FalseEdge)) - succOf("c") shouldBe expected(("b", AlwaysEdge)) - succOf("u") shouldBe expected(("c", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("a", TrueEdge), ("RET", FalseEdge)) + succOf("z") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("b") should contain theSameElementsAs expected(("u", TrueEdge), ("z", FalseEdge)) + succOf("c") should contain theSameElementsAs expected(("b", AlwaysEdge)) + succOf("u") should contain theSameElementsAs expected(("c", AlwaysEdge)) } "be correct for for-loop with empty condition" in { implicit val cpg: Cpg = code("for (;;) { a = 1; }") - succOf(":program") shouldBe expected(("true", AlwaysEdge)) - succOf("true") shouldBe expected(("a", TrueEdge), ("RET", FalseEdge)) - succOf("a") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("a = 1", AlwaysEdge)) - succOf("a = 1") shouldBe expected(("true", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("true", AlwaysEdge)) + succOf("true") should contain theSameElementsAs expected(("a", TrueEdge), ("RET", FalseEdge)) + succOf("a") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("a = 1", AlwaysEdge)) + succOf("a = 1") should contain theSameElementsAs expected(("true", AlwaysEdge)) } "be correct for for-loop with empty condition and break" in { implicit val cpg: Cpg = code("for (;;) { break; }") - succOf(":program") shouldBe expected(("true", AlwaysEdge)) - succOf("true") shouldBe expected(("break;", TrueEdge), ("RET", FalseEdge)) - succOf("break;") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("true", AlwaysEdge)) + succOf("true") should contain theSameElementsAs expected(("break;", TrueEdge), ("RET", FalseEdge)) + succOf("break;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for for-loop with empty condition and continue" in { implicit val cpg: Cpg = code("for (;;) { continue; }") - succOf(":program") shouldBe expected(("true", AlwaysEdge)) - succOf("true") shouldBe expected(("continue;", TrueEdge), ("RET", FalseEdge)) - succOf("continue;") shouldBe expected(("true", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("true", AlwaysEdge)) + succOf("true") should contain theSameElementsAs expected(("continue;", TrueEdge), ("RET", FalseEdge)) + succOf("continue;") should contain theSameElementsAs expected(("true", AlwaysEdge)) } "be correct with empty condition with nested empty for-loop" in { implicit val cpg: Cpg = code("for (;;) { for (;;) { x; } }") - succOf(":program") shouldBe expected(("true", AlwaysEdge)) - succOf("true") shouldBe expected(("true", 1, TrueEdge), ("RET", FalseEdge)) - succOf("true", 1) shouldBe expected(("x", TrueEdge), ("true", 0, FalseEdge)) - succOf("x") shouldBe expected(("true", 1, AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("true", AlwaysEdge)) + succOf("true") should contain theSameElementsAs expected(("true", 1, TrueEdge), ("RET", FalseEdge)) + succOf("true", 1) should contain theSameElementsAs expected(("x", TrueEdge), ("true", 0, FalseEdge)) + succOf("x") should contain theSameElementsAs expected(("true", 1, AlwaysEdge)) } "be correct for for-loop with empty block" in { implicit val cpg: Cpg = code("for (;;) ;") - succOf(":program") shouldBe expected(("true", AlwaysEdge)) - succOf("true") shouldBe expected(("true", TrueEdge), ("RET", FalseEdge)) + succOf(":program") should contain theSameElementsAs expected(("true", AlwaysEdge)) + succOf("true") should contain theSameElementsAs expected(("true", TrueEdge), ("RET", FalseEdge)) } "be correct for simple if statement" in { implicit val cpg: Cpg = code("if (x) { y; }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("RET", FalseEdge)) - succOf("y") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("RET", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for simple if statement with else block" in { implicit val cpg: Cpg = code("if (x) { y; } else { z; }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("z", FalseEdge)) - succOf("y") shouldBe expected(("RET", AlwaysEdge)) - succOf("z") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("z", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for nested if statement" in { implicit val cpg: Cpg = code("if (x) { if (y) { z; } }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("y", TrueEdge), ("RET", FalseEdge)) - succOf("y") shouldBe expected(("z", TrueEdge), ("RET", FalseEdge)) - succOf("z") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("y", TrueEdge), ("RET", FalseEdge)) + succOf("y") should contain theSameElementsAs expected(("z", TrueEdge), ("RET", FalseEdge)) + succOf("z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for nested if statement with else-if chains" in { implicit val cpg: Cpg = code("if (a) { b; } else if (c) { d;} else { e; }") - succOf(":program") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("b", TrueEdge), ("c", FalseEdge)) - succOf("b") shouldBe expected(("RET", AlwaysEdge)) - succOf("c") shouldBe expected(("d", TrueEdge), ("e", FalseEdge)) - succOf("d") shouldBe expected(("RET", AlwaysEdge)) - succOf("e") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("a", AlwaysEdge)) + succOf("a") should contain theSameElementsAs expected(("b", TrueEdge), ("c", FalseEdge)) + succOf("b") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("c") should contain theSameElementsAs expected(("d", TrueEdge), ("e", FalseEdge)) + succOf("d") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("e") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for switch-case with single case" in { implicit val cpg: Cpg = code("switch (x) { case 1: y;}") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("RET", CaseEdge)) - succOf("case 1:") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("case 1:", CaseEdge), ("RET", CaseEdge)) + succOf("case 1:") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for switch-case with multiple cases" in { implicit val cpg: Cpg = code("switch (x) { case 1: y; case 2: z;}") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("case 2:", CaseEdge), ("RET", CaseEdge)) - succOf("case 1:") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("case 2:", AlwaysEdge)) - succOf("case 2:") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected( + ("case 1:", CaseEdge), + ("case 2:", CaseEdge), + ("RET", CaseEdge) + ) + succOf("case 1:") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("case 2:", AlwaysEdge)) + succOf("case 2:") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for switch-case with multiple cases on the same spot" in { implicit val cpg: Cpg = code("switch (x) { case 1: case 2: y; }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("case 2:", CaseEdge), ("RET", CaseEdge)) - succOf("case 1:") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("case 2:", AlwaysEdge)) - succOf("case 2:") shouldBe expected(("2", AlwaysEdge)) - succOf("2") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected( + ("case 1:", CaseEdge), + ("case 2:", CaseEdge), + ("RET", CaseEdge) + ) + succOf("case 1:") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("case 2:", AlwaysEdge)) + succOf("case 2:") should contain theSameElementsAs expected(("2", AlwaysEdge)) + succOf("2") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for switch-case with default case" in { implicit val cpg: Cpg = code("switch (x) { default: y; }") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("default:", CaseEdge)) - succOf("default:") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("default:", CaseEdge)) + succOf("default:") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for switch-case with multiple cases and default combined" in { implicit val cpg: Cpg = code("switch (x) { case 1: y; break; default: z;}") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("default:", CaseEdge)) - succOf("case 1:") shouldBe expected(("1", AlwaysEdge)) - succOf("1") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("break;", AlwaysEdge)) - succOf("break;") shouldBe expected(("RET", AlwaysEdge)) - succOf("default:") shouldBe expected(("z", AlwaysEdge)) - succOf("z") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("case 1:", CaseEdge), ("default:", CaseEdge)) + succOf("case 1:") should contain theSameElementsAs expected(("1", AlwaysEdge)) + succOf("1") should contain theSameElementsAs expected(("y", AlwaysEdge)) + succOf("y") should contain theSameElementsAs expected(("break;", AlwaysEdge)) + succOf("break;") should contain theSameElementsAs expected(("RET", AlwaysEdge)) + succOf("default:") should contain theSameElementsAs expected(("z", AlwaysEdge)) + succOf("z") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for constructor call with new" in { implicit val cpg: Cpg = code("""var x = new MyClass(arg1, arg2)""") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("_tmp_0", AlwaysEdge)) - succOf("_tmp_0") shouldBe expected((".alloc", AlwaysEdge)) - succOf(".alloc") shouldBe expected(("_tmp_0 = .alloc", AlwaysEdge)) - succOf("_tmp_0 = .alloc") shouldBe expected(("MyClass", AlwaysEdge)) - succOf("MyClass") shouldBe expected(("_tmp_0", 1, AlwaysEdge)) - succOf("_tmp_0", 1) shouldBe expected(("arg1", AlwaysEdge)) - succOf("arg1") shouldBe expected(("arg2", AlwaysEdge)) - succOf("arg2") shouldBe expected(("new MyClass(arg1, arg2)", AlwaysEdge)) - succOf("new MyClass(arg1, arg2)", NodeTypes.CALL) shouldBe expected(("_tmp_0", 2, AlwaysEdge)) - succOf("_tmp_0", 2) shouldBe expected(("new MyClass(arg1, arg2)", AlwaysEdge)) - succOf("new MyClass(arg1, arg2)") shouldBe expected(("var x = new MyClass(arg1, arg2)", AlwaysEdge)) - succOf("var x = new MyClass(arg1, arg2)") shouldBe expected(("RET", AlwaysEdge)) + succOf(":program") should contain theSameElementsAs expected(("x", AlwaysEdge)) + succOf("x") should contain theSameElementsAs expected(("_tmp_0", AlwaysEdge)) + succOf("_tmp_0") should contain theSameElementsAs expected((".alloc", AlwaysEdge)) + succOf(".alloc") should contain theSameElementsAs expected(("_tmp_0 = .alloc", AlwaysEdge)) + succOf("_tmp_0 = .alloc") should contain theSameElementsAs expected(("MyClass", AlwaysEdge)) + succOf("MyClass") should contain theSameElementsAs expected(("_tmp_0", 1, AlwaysEdge)) + succOf("_tmp_0", 1) should contain theSameElementsAs expected(("arg1", AlwaysEdge)) + succOf("arg1") should contain theSameElementsAs expected(("arg2", AlwaysEdge)) + succOf("arg2") should contain theSameElementsAs expected(("new MyClass(arg1, arg2)", AlwaysEdge)) + succOf("new MyClass(arg1, arg2)", NodeTypes.CALL) should contain theSameElementsAs expected( + ("_tmp_0", 2, AlwaysEdge) + ) + succOf("_tmp_0", 2) should contain theSameElementsAs expected(("new MyClass(arg1, arg2)", AlwaysEdge)) + succOf("new MyClass(arg1, arg2)") should contain theSameElementsAs expected( + ("var x = new MyClass(arg1, arg2)", AlwaysEdge) + ) + succOf("var x = new MyClass(arg1, arg2)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } private def testForInOrOf()(implicit cpg: Cpg): Unit = { - succOf(":program") shouldBe expected(("_iterator_0", AlwaysEdge)) - succOf("_iterator_0") shouldBe expected(("arr", AlwaysEdge)) - succOf("arr") shouldBe expected((".iterator(arr)", AlwaysEdge)) - succOf(".iterator(arr)") shouldBe expected(("_iterator_0 = .iterator(arr)", AlwaysEdge)) - succOf("_iterator_0 = .iterator(arr)") shouldBe expected(("_result_0", AlwaysEdge)) - succOf("_result_0") shouldBe expected(("i", AlwaysEdge)) - succOf("i") shouldBe expected(("_result_0", 1, AlwaysEdge)) - succOf("_result_0", 1) shouldBe expected(("_iterator_0", 1, AlwaysEdge)) - succOf("_iterator_0", 1) shouldBe expected(("next", AlwaysEdge)) - succOf("next") shouldBe expected(("_iterator_0.next", AlwaysEdge)) - succOf("_iterator_0.next") shouldBe expected(("_iterator_0", 2, AlwaysEdge)) - succOf("_iterator_0", 2) shouldBe expected(("_iterator_0.next()", AlwaysEdge)) - succOf("_iterator_0.next()") shouldBe expected(("(_result_0 = _iterator_0.next())", AlwaysEdge)) - succOf("(_result_0 = _iterator_0.next())") shouldBe expected(("done", AlwaysEdge)) - succOf("done") shouldBe expected(("(_result_0 = _iterator_0.next()).done", AlwaysEdge)) - succOf("(_result_0 = _iterator_0.next()).done") shouldBe expected( + succOf(":program") should contain theSameElementsAs expected(("_iterator_0", AlwaysEdge)) + succOf("_iterator_0") should contain theSameElementsAs expected(("arr", AlwaysEdge)) + succOf("arr") should contain theSameElementsAs expected((".iterator(arr)", AlwaysEdge)) + succOf(".iterator(arr)") should contain theSameElementsAs expected( + ("_iterator_0 = .iterator(arr)", AlwaysEdge) + ) + succOf("_iterator_0 = .iterator(arr)") should contain theSameElementsAs expected( + ("_result_0", AlwaysEdge) + ) + succOf("_result_0") should contain theSameElementsAs expected(("i", AlwaysEdge)) + succOf("i") should contain theSameElementsAs expected(("_result_0", 1, AlwaysEdge)) + succOf("_result_0", 1) should contain theSameElementsAs expected(("_iterator_0", 1, AlwaysEdge)) + succOf("_iterator_0", 1) should contain theSameElementsAs expected(("next", AlwaysEdge)) + succOf("next") should contain theSameElementsAs expected(("_iterator_0.next", AlwaysEdge)) + succOf("_iterator_0.next") should contain theSameElementsAs expected(("_iterator_0", 2, AlwaysEdge)) + succOf("_iterator_0", 2) should contain theSameElementsAs expected(("_iterator_0.next()", AlwaysEdge)) + succOf("_iterator_0.next()") should contain theSameElementsAs expected( + ("(_result_0 = _iterator_0.next())", AlwaysEdge) + ) + succOf("(_result_0 = _iterator_0.next())") should contain theSameElementsAs expected(("done", AlwaysEdge)) + succOf("done") should contain theSameElementsAs expected(("(_result_0 = _iterator_0.next()).done", AlwaysEdge)) + succOf("(_result_0 = _iterator_0.next()).done") should contain theSameElementsAs expected( ("!(_result_0 = _iterator_0.next()).done", AlwaysEdge) ) import io.shiftleft.semanticcpg.language._ val code = cpg.method.ast.isBlock.code("for \\(var i.*foo.*}").code.head - succOf("!(_result_0 = _iterator_0.next()).done") shouldBe expected(("i", 1, TrueEdge), (code, FalseEdge)) - succOf(code) shouldBe expected(("RET", AlwaysEdge)) - - succOf("i", 1) shouldBe expected(("_result_0", 2, AlwaysEdge)) - succOf("_result_0", 2) shouldBe expected(("value", AlwaysEdge)) - succOf("value") shouldBe expected(("_result_0.value", AlwaysEdge)) - succOf("_result_0.value") shouldBe expected(("i = _result_0.value", AlwaysEdge)) - succOf("i = _result_0.value") shouldBe expected(("foo", AlwaysEdge)) - succOf("foo") shouldBe expected(("this", 1, AlwaysEdge)) - succOf("this", 1) shouldBe expected(("i", 2, AlwaysEdge)) - succOf("i", 2) shouldBe expected(("foo(i)", AlwaysEdge)) + succOf("!(_result_0 = _iterator_0.next()).done") should contain theSameElementsAs expected( + ("i", 1, TrueEdge), + (code, FalseEdge) + ) + succOf(code) should contain theSameElementsAs expected(("RET", AlwaysEdge)) + + succOf("i", 1) should contain theSameElementsAs expected(("_result_0", 2, AlwaysEdge)) + succOf("_result_0", 2) should contain theSameElementsAs expected(("value", AlwaysEdge)) + succOf("value") should contain theSameElementsAs expected(("_result_0.value", AlwaysEdge)) + succOf("_result_0.value") should contain theSameElementsAs expected(("i = _result_0.value", AlwaysEdge)) + succOf("i = _result_0.value") should contain theSameElementsAs expected(("foo", AlwaysEdge)) + succOf("foo") should contain theSameElementsAs expected(("this", 1, AlwaysEdge)) + succOf("this", 1) should contain theSameElementsAs expected(("i", 2, AlwaysEdge)) + succOf("i", 2) should contain theSameElementsAs expected(("foo(i)", AlwaysEdge)) val code2 = "{ foo(i) }" - succOf("foo(i)") shouldBe expected((code2, AlwaysEdge)) - succOf(code2) shouldBe expected(("_result_0", 1, AlwaysEdge)) + succOf("foo(i)") should contain theSameElementsAs expected((code2, AlwaysEdge)) + succOf(code2) should contain theSameElementsAs expected(("_result_0", 1, AlwaysEdge)) } } diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/preprocessing/EjsPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/preprocessing/EjsPassTests.scala index 7eda70ca7a15..a1c8f9593d58 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/preprocessing/EjsPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/preprocessing/EjsPassTests.scala @@ -7,7 +7,7 @@ class EjsPassTests extends AstJsSrc2CpgSuite { "ejs files" should { - "be renamed correctly " in { + "be renamed correctly" in { val cpg = code( """ | @@ -20,6 +20,20 @@ class EjsPassTests extends AstJsSrc2CpgSuite { cpg.call.code.l.sorted shouldBe List("user.name") } + "be ignored at folders excluded by default" in { + val codeString = """ + | + |

Welcome <%= user.name %>

+ | + |""".stripMargin + val cpg = code(codeString, "index.js.ejs") + .moreCode(codeString, "node_modules/foo.js.ejs") + .moreCode(codeString, "vendor/bar.js.ejs") + .moreCode(codeString, "www/baz.js.ejs") + cpg.file.name.l shouldBe List("index.js.ejs") + cpg.call.code.l.sorted shouldBe List("user.name") + } + "be handled correctly" in { val cpg = code( """ diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/testfixtures/DataFlowCodeToCpgSuite.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/testfixtures/DataFlowCodeToCpgSuite.scala index eb93c807ba97..3e6a9fdaab97 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/testfixtures/DataFlowCodeToCpgSuite.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/testfixtures/DataFlowCodeToCpgSuite.scala @@ -31,7 +31,7 @@ class DataFlowCodeToCpgSuite extends Code2CpgFixture(() => new DataFlowTestCpg() protected implicit val context: EngineContext = EngineContext() protected def flowToResultPairs(path: Path): List[(String, Integer)] = - path.resultPairs().collect { case (firstElement: String, secondElement: Option[Integer]) => + path.resultPairs().collect { case (firstElement: String, secondElement) => (firstElement, secondElement.getOrElse(-1)) } } diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/testfixtures/JsSrc2CpgSuite.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/testfixtures/JsSrc2CpgSuite.scala index f188520b43c9..43cbdc216aaa 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/testfixtures/JsSrc2CpgSuite.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/testfixtures/JsSrc2CpgSuite.scala @@ -1,16 +1,17 @@ package io.joern.jssrc2cpg.testfixtures -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic +import io.joern.dataflowengineoss.DefaultSemantics +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.joern.x2cpg.testfixtures.Code2CpgFixture class JsSrc2CpgSuite( fileSuffix: String = ".js", withOssDataflow: Boolean = false, - extraFlows: List[FlowSemantic] = List.empty, + semantics: Semantics = DefaultSemantics(), withPostProcessing: Boolean = false ) extends Code2CpgFixture(() => new JsSrcDefaultTestCpg(fileSuffix) .withOssDataflow(withOssDataflow) - .withExtraFlows(extraFlows) + .withSemantics(semantics) .withPostProcessingPasses(withPostProcessing) ) diff --git a/joern-cli/frontends/kotlin2cpg/build.sbt b/joern-cli/frontends/kotlin2cpg/build.sbt index 83032b7c6024..825505cc8cea 100644 --- a/joern-cli/frontends/kotlin2cpg/build.sbt +++ b/joern-cli/frontends/kotlin2cpg/build.sbt @@ -23,4 +23,5 @@ libraryDependencies ++= Seq( enablePlugins(JavaAppPackaging, LauncherJarPlugin) trapExit := false -Test / fork := false +Test / fork := true +Test / javaOptions ++= Seq("-ea") diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Constants.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Constants.scala index 7591307422fd..7216d73956cf 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Constants.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Constants.scala @@ -1,36 +1,33 @@ package io.joern.kotlin2cpg object Constants { - val alloc = "alloc" - val caseNodeParserTypeName = "CaseNode" - val caseNodePrefix = "case" - val codeForLoweredForBlock = "FOR-BLOCK" // TODO: improve this - val collectionsIteratorName = "kotlin.collections.Iterator" - val companionObjectMemberName = "object" - val componentNPrefix = "component" - val defaultCaseNode = "default" - val empty = "" - val getIteratorMethodName = "iterator" - val hasNextIteratorMethodName = "hasNext" - val importKeyword = "import" - val init = io.joern.x2cpg.Defines.ConstructorMethodName - val iteratorPrefix = "iterator_" - val javaUtilIterator = "java.util.Iterator" - val lambdaBindingName = "invoke" // the underlying _invoke_ fn for Kotlin FunctionX types - val lambdaTypeDeclName = "LAMBDA_TYPE_DECL" - val nextIteratorMethodName = "next" - val codePropUndefinedValue = "" - val operatorSuffix = "" - val paramNameLambdaDestructureDecl = "DESTRUCTURE_PARAM" - val parserTypeName = "KOTLIN_PSI_PARSER" - val retCode = "RET" - val ret = "RET" - val root = "" - val this_ = "this" - val tmpLocalPrefix = "tmp_" - val tryCode = "try" - val unusedDestructuringEntryText = "_" - val unknownOperator = ".unknown" - val when = "when" - val wildcardImportName = "*" + val Alloc = "alloc" + val CaseNodeParserTypeName = "CaseNode" + val CaseNodePrefix = "case" + val CodeForLoweredForBlock = "FOR-BLOCK" // TODO: improve this + val CollectionsIteratorName = "kotlin.collections.Iterator" + val CompanionObjectMemberName = "object" + val ComponentNPrefix = "component" + val DefaultCaseNode = "default" + val Empty = "" + val GetIteratorMethodName = "iterator" + val HasNextIteratorMethodName = "hasNext" + val ImportKeyword = "import" + val IteratorPrefix = "iterator_" + val JavaUtilIterator = "java.util.Iterator" + val UnknownLambdaBindingName = "" + val UnknownLambdaBaseClass = "" + val LambdaTypeDeclName = "LAMBDA_TYPE_DECL" + val NextIteratorMethodName = "next" + val CodePropUndefinedValue = "" + val OperatorSuffix = "" + val DestructedParamNamePrefix = "" + val RetCode = "RET" + val Root = "" + val ThisName = "this" + val TmpLocalPrefix = "tmp_" + val UnusedDestructuringEntryText = "_" + val UnknownOperator = ".unknown" + val WhenKeyword = "when" + val WildcardImportName = "*" } diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Kotlin2Cpg.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Kotlin2Cpg.scala index 2a19856dbb2e..984382b24142 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Kotlin2Cpg.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Kotlin2Cpg.scala @@ -4,10 +4,10 @@ import better.files.File import io.joern.kotlin2cpg.compiler.CompilerAPI import io.joern.kotlin2cpg.compiler.ErrorLoggingMessageCollector import io.joern.kotlin2cpg.files.SourceFilesPicker -import io.joern.kotlin2cpg.interop.JavasrcInterop +import io.joern.kotlin2cpg.interop.JavaSrcInterop import io.joern.kotlin2cpg.jar4import.UsesService import io.joern.kotlin2cpg.passes.* -import io.joern.kotlin2cpg.types.{ContentSourcesPicker, DefaultTypeInfoProvider, TypeRenderer} +import io.joern.kotlin2cpg.types.{ContentSourcesPicker, TypeInfoProvider} import io.joern.x2cpg.SourceFiles import io.joern.x2cpg.X2CpgFrontend import io.joern.x2cpg.X2Cpg.withNewEmptyCpg @@ -20,8 +20,10 @@ import io.joern.x2cpg.SourceFiles.filterFile import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.Languages import io.shiftleft.utils.IOUtils -import org.jetbrains.kotlin.cli.jvm.compiler.KotlinCoreEnvironment +import org.jetbrains.kotlin.cli.jvm.compiler.{KotlinCoreEnvironment, KotlinToJVMBytecodeCompiler} +import org.jetbrains.kotlin.com.intellij.openapi.util.Disposer import org.jetbrains.kotlin.psi.KtFile +import org.jetbrains.kotlin.resolve.BindingContext import org.slf4j.LoggerFactory import java.nio.file.Files @@ -33,11 +35,9 @@ import scala.util.matching.Regex object Kotlin2Cpg { - private val logger = LoggerFactory.getLogger(getClass) - - private val parsingError: String = "KOTLIN2CPG_PARSING_ERROR" - private val jarExtension: String = ".jar" - private val importRegex: Regex = ".*import([^;]*).*".r + private val logger = LoggerFactory.getLogger(getClass) + private val JarExtension: String = ".jar" + private val ImportPattern: Regex = ".*import([^;]*).*".r private val defaultKotlinStdlibContentRootJarPaths = Seq( DefaultContentRootJarPath("jars/kotlin-stdlib-1.9.0.jar", isResource = true), @@ -45,9 +45,6 @@ object Kotlin2Cpg { DefaultContentRootJarPath("jars/kotlin-stdlib-jdk8-1.9.0.jar", isResource = true) ) - case class InputPair(content: String, fileName: String) - type InputProvider = () => InputPair - def postProcessingPass(cpg: Cpg): Unit = { new KotlinTypeRecoveryPassGenerator(cpg).generate().foreach(_.createAndApply()) new KotlinTypeHintCallLinker(cpg).createAndApply() @@ -132,7 +129,7 @@ class Kotlin2Cpg extends X2CpgFrontend[Config] with UsesService { } else Seq() } - private def gatherJarsAtConfigClassPath(sourceDir: String, config: Config): Seq[String] = { + private def gatherJarsAtConfigClassPath(config: Config): Seq[String] = { val jarsAtConfigClassPath = findJarsIn(config.classpath) if (config.classpath.nonEmpty) { if (jarsAtConfigClassPath.nonEmpty) { @@ -150,7 +147,7 @@ class Kotlin2Cpg extends X2CpgFrontend[Config] with UsesService { filesWithJavaExtension: List[String] ): Seq[DefaultContentRootJarPath] = { val stdlibJars = if (config.withStdlibJarsInClassPath) defaultKotlinStdlibContentRootJarPaths else Seq() - val jarsAtConfigClassPath = gatherJarsAtConfigClassPath(sourceDir, config) + val jarsAtConfigClassPath = gatherJarsAtConfigClassPath(config) val dependenciesPaths = gatherDependenciesPaths(sourceDir, config, filesWithJavaExtension) val defaultContentRootJars = stdlibJars ++ jarsAtConfigClassPath.map { path => DefaultContentRootJarPath(path, isResource = false) } ++ @@ -185,17 +182,20 @@ class Kotlin2Cpg extends X2CpgFrontend[Config] with UsesService { sourceFiles } - private def runJavasrcInterop( + private def runJavaSrcInterop( cpg: Cpg, - sourceDir: String, config: Config, filesWithJavaExtension: List[String], kotlinAstCreatorTypes: List[String] ): Unit = { if (config.includeJavaSourceFiles && filesWithJavaExtension.nonEmpty) { - val javaAstCreator = JavasrcInterop.astCreationPass(config.inputPath, filesWithJavaExtension, cpg) + val javaAstCreator = JavaSrcInterop.astCreationPass(config.inputPath, filesWithJavaExtension, cpg) javaAstCreator.createAndApply() val javaAstCreatorTypes = javaAstCreator.global.usedTypes.keys().asScala.toList + + javaAstCreator.sourceParser.cleanupDelombokOutput() + javaAstCreator.clearJavaParserCaches() + TypeNodePass .withRegisteredTypes((javaAstCreatorTypes.toSet -- kotlinAstCreatorTypes.toSet).toList, cpg) .createAndApply() @@ -226,23 +226,23 @@ class Kotlin2Cpg extends X2CpgFrontend[Config] with UsesService { new MetaDataPass(cpg, Languages.KOTLIN, config.inputPath).createAndApply() - val typeRenderer = new TypeRenderer(config.keepTypeArguments) - val astCreator = new AstCreationPass(sourceFiles, new DefaultTypeInfoProvider(environment, typeRenderer), cpg)( - config.schemaValidation - ) + val bindingContext = createBindingContext(environment) + val astCreator = new AstCreationPass(sourceFiles, bindingContext, cpg)(config.schemaValidation) astCreator.createAndApply() + Disposer.dispose(environment.getProjectEnvironment.getParentDisposable) + val kotlinAstCreatorTypes = astCreator.usedTypes() TypeNodePass.withRegisteredTypes(kotlinAstCreatorTypes, cpg).createAndApply() - runJavasrcInterop(cpg, sourceDir, config, filesWithJavaExtension, kotlinAstCreatorTypes) + runJavaSrcInterop(cpg, config, filesWithJavaExtension, kotlinAstCreatorTypes) new ConfigPass(configFiles, cpg).createAndApply() new DependenciesFromMavenCoordinatesPass(mavenCoordinates, cpg).createAndApply() } } private def importNamesForFilesAtPaths(paths: Seq[String]): Seq[String] = { - paths.flatMap(File(_).lines.filter(_.startsWith("import")).toSeq).map(importRegex.replaceAllIn(_, "$1").trim) + paths.flatMap(File(_).lines.filter(_.startsWith("import")).toSeq).map(ImportPattern.replaceAllIn(_, "$1").trim) } private def gatherGradleParams(config: Config) = { @@ -286,7 +286,7 @@ class Kotlin2Cpg extends X2CpgFrontend[Config] with UsesService { dirs.foldLeft(Seq[String]())((acc, classpathEntry) => { val f = File(classpathEntry) val files = - if (f.isDirectory) f.listRecursively.filter(_.extension.getOrElse("") == jarExtension).map(_.toString) + if (f.isDirectory) f.listRecursively.filter(_.extension.getOrElse("") == JarExtension).map(_.toString) else Seq() acc ++ files }) @@ -315,4 +315,19 @@ class Kotlin2Cpg extends X2CpgFrontend[Config] with UsesService { fileContents <- Try(IOUtils.readEntireFile(Paths.get(fileName))).toOption } yield FileContentAtPath(fileContents, relPath, fileName) } + + private def createBindingContext(environment: KotlinCoreEnvironment): BindingContext = { + try { + logger.info("Running Kotlin compiler analysis...") + val t0 = System.nanoTime() + val analysisResult = KotlinToJVMBytecodeCompiler.INSTANCE.analyze(environment) + val t1 = System.nanoTime() + logger.info(s"Kotlin compiler analysis finished in `${(t1 - t0) / 1000000}` ms.") + analysisResult.getBindingContext + } catch { + case exc: Exception => + logger.error(s"Kotlin compiler analysis failed with exception `${exc.toString}`:`${exc.getMessage}`.", exc) + BindingContext.EMPTY + } + } } diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Main.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Main.scala index 4d20c13c1e90..e4fa1f880b97 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Main.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Main.scala @@ -2,6 +2,7 @@ package io.joern.kotlin2cpg import io.joern.kotlin2cpg.Frontend.* import io.joern.x2cpg.{DependencyDownloadConfig, X2CpgConfig, X2CpgMain} +import io.joern.x2cpg.utils.server.FrontendHTTPServer import scopt.OParser case class DefaultContentRootJarPath(path: String, isResource: Boolean) @@ -91,13 +92,17 @@ private object Frontend { opt[Unit]("keep-type-arguments") .hidden() .action((_, c) => c.withKeepTypeArguments(true)) - .text("Type full names of variables keep their type arguments.") + .text("Type full names of variables keep their type arguments. (Deprecated, no effect.") ) } } -object Main extends X2CpgMain(cmdLineParser, new Kotlin2Cpg()) { +object Main extends X2CpgMain(cmdLineParser, new Kotlin2Cpg()) with FrontendHTTPServer[Config, Kotlin2Cpg] { + + override protected def newDefaultConfig(): Config = Config() + def run(config: Config, kotlin2cpg: Kotlin2Cpg): Unit = { - kotlin2cpg.run(config) + if (config.serverMode) { startup() } + else { kotlin2cpg.run(config) } } } diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstCreator.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstCreator.scala index 11bfd8af9afd..ba8dd372ea56 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstCreator.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstCreator.scala @@ -3,9 +3,9 @@ package io.joern.kotlin2cpg.ast import io.joern.kotlin2cpg.Constants import io.joern.kotlin2cpg.KtFileWithMeta import io.joern.kotlin2cpg.datastructures.Scope +import io.joern.kotlin2cpg.types.NameRenderer import io.joern.kotlin2cpg.types.TypeConstants import io.joern.kotlin2cpg.types.TypeInfoProvider -import io.joern.kotlin2cpg.types.TypeRenderer import io.joern.x2cpg.Ast import io.joern.x2cpg.AstCreatorBase import io.joern.x2cpg.AstNodeBuilder @@ -13,33 +13,38 @@ import io.joern.x2cpg.Defines import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.datastructures.Global import io.joern.x2cpg.datastructures.Stack.* +import io.joern.x2cpg.utils.IntervalKeyPool import io.joern.x2cpg.utils.NodeBuilders import io.joern.x2cpg.utils.NodeBuilders.newMethodReturnNode import io.shiftleft.codepropertygraph.generated.* import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.passes.IntervalKeyPool +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import org.jetbrains.kotlin.com.intellij.psi.PsiElement +import org.jetbrains.kotlin.descriptors.DeclarationDescriptor import org.jetbrains.kotlin.descriptors.DescriptorVisibilities import org.jetbrains.kotlin.descriptors.DescriptorVisibility +import org.jetbrains.kotlin.descriptors.FunctionDescriptor import org.jetbrains.kotlin.lexer.KtToken import org.jetbrains.kotlin.lexer.KtTokens import org.jetbrains.kotlin.psi.* +import org.jetbrains.kotlin.resolve.BindingContext import org.slf4j.Logger import org.slf4j.LoggerFactory -import overflowdb.BatchedUpdate.DiffGraphBuilder -import java.io.PrintWriter -import java.io.StringWriter import scala.annotation.tailrec import scala.collection.mutable import scala.jdk.CollectionConverters.* +import scala.util.Try -case class BindingInfo(node: NewBinding, edgeMeta: Seq[(NewNode, NewNode, String)]) -case class ClosureBindingDef(node: NewClosureBinding, captureEdgeTo: NewMethodRef, refEdgeTo: NewNode) +object AstCreator { + case class AnonymousObjectContext(declaration: KtElement) + case class BindingInfo(node: NewBinding, edgeMeta: Seq[(NewNode, NewNode, String)]) + case class ClosureBindingDef(node: NewClosureBinding, captureEdgeTo: NewMethodRef, refEdgeTo: NewNode) +} -class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvider, global: Global)(implicit +class AstCreator(fileWithMeta: KtFileWithMeta, bindingContext: BindingContext, global: Global)(implicit withSchemaValidation: ValidationMode ) extends AstCreatorBase(fileWithMeta.filename) with AstForDeclarationsCreator @@ -49,26 +54,31 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid with AstForExpressionsCreator with AstNodeBuilder[PsiElement, AstCreator] { + import AstCreator.BindingInfo + import AstCreator.ClosureBindingDef + protected val closureBindingDefQueue: mutable.ArrayBuffer[ClosureBindingDef] = mutable.ArrayBuffer.empty protected val bindingInfoQueue: mutable.ArrayBuffer[BindingInfo] = mutable.ArrayBuffer.empty protected val lambdaAstQueue: mutable.ArrayBuffer[Ast] = mutable.ArrayBuffer.empty protected val lambdaBindingInfoQueue: mutable.ArrayBuffer[BindingInfo] = mutable.ArrayBuffer.empty protected val methodAstParentStack: Stack[NewNode] = new Stack() - protected val tmpKeyPool = new IntervalKeyPool(first = 1, last = Long.MaxValue) - protected val iteratorKeyPool = new IntervalKeyPool(first = 1, last = Long.MaxValue) + protected val tmpKeyPool = new IntervalKeyPool(first = 1, last = Long.MaxValue) + protected val destructedParamKeyPool = new IntervalKeyPool(first = 1, last = Long.MaxValue) + protected val iteratorKeyPool = new IntervalKeyPool(first = 1, last = Long.MaxValue) protected val relativizedPath: String = fileWithMeta.relativizedPath protected val scope: Scope[String, DeclarationNew, NewNode] = new Scope() protected val debugScope: mutable.Stack[KtDeclaration] = mutable.Stack.empty[KtDeclaration] + protected val nameRenderer = new NameRenderer() + protected val bindingUtils = new BindingContextUtils(bindingContext) + protected val typeInfoProvider = new TypeInfoProvider(bindingContext) + def createAst(): DiffGraphBuilder = { - implicit val typeInfoProvider: TypeInfoProvider = xTypeInfoProvider logger.debug(s"Started parsing file `${fileWithMeta.filename}`.") - - val defaultTypes = - Set(TypeConstants.javaLangObject, TypeConstants.kotlin) ++ TypeRenderer.primitiveArrayMappings.keys + val defaultTypes = Set(TypeConstants.JavaLangObject, TypeConstants.Kotlin) defaultTypes.foreach(registerType) storeInDiffGraph(astForFile(fileWithMeta)) diffGraph @@ -84,33 +94,77 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid typeName } - // TODO: use this everywhere in kotlin2cpg instead of manual .getText calls - override def code(element: PsiElement): String = shortenCode(element.getText) - - override def line(element: PsiElement): Option[Int] = { - try { - Some( - element.getContainingFile.getViewProvider.getDocument - .getLineNumber(element.getTextOffset) + 1 - ) - } catch { - case _: Throwable => None + protected def getFallback[T]( + expr: KtExpression, + propertyExtractor: FunctionDescriptor => Option[T] + ): Option[FunctionDescriptor] = { + val candidates = bindingUtils.getAmbiguousCalledFunctionDescs(expr) + if (candidates.isEmpty) { + return None + } + + val candidateProperties = candidates.map(propertyExtractor) + val allPropertiesEqual = candidateProperties.forall(_ == candidateProperties.head) + + if (allPropertiesEqual) { + candidates.headOption + } else { + None } } - override def column(element: PsiElement): Option[Int] = { - try { - val lineNumber = - element.getContainingFile.getViewProvider.getDocument - .getLineNumber(element.getTextOffset) - val lineOffset = - element.getContainingFile.getViewProvider.getDocument.getLineStartOffset(lineNumber) - Some(element.getTextOffset - lineOffset) - } catch { - case _: Throwable => None + protected def getAmbiguousFuncDescIfFullNamesEqual(expr: KtExpression): Option[FunctionDescriptor] = { + getFallback(expr, nameRenderer.descFullName) + } + + protected def getAmbiguousFuncDescIfSignaturesEqual(expr: KtExpression): Option[FunctionDescriptor] = { + getFallback(expr, nameRenderer.funcDescSignature) + } + + protected def calleeFullnameAndSignature( + calleeExpr: KtExpression, + fullNameFallback: => String, + signatureFallback: => String + ): (String, String) = { + val funcDesc = bindingUtils.getCalledFunctionDesc(calleeExpr) + val descFullName = funcDesc + .orElse(getAmbiguousFuncDescIfFullNamesEqual(calleeExpr)) + .flatMap(nameRenderer.descFullName) + .getOrElse(fullNameFallback) + val signature = funcDesc + .orElse(getAmbiguousFuncDescIfSignaturesEqual(calleeExpr)) + .flatMap(nameRenderer.funcDescSignature) + .getOrElse(signatureFallback) + val fullName = nameRenderer.combineFunctionFullName(descFullName, signature) + + (fullName, signature) + } + + protected def getCalleeExpr(expr: KtExpression): KtExpression = { + expr match { + case qualifiedExpression: KtQualifiedExpression => + getCalleeExpr(qualifiedExpression.getSelectorExpression) + case callExpr: KtCallExpression => + callExpr.getCalleeExpression } } + // TODO: use this everywhere in kotlin2cpg instead of manual .getText calls + override def code(element: PsiElement): String = + shortenCode(element.getText) + + override def line(element: PsiElement): Option[Int] = + Try(element.getContainingFile.getViewProvider.getDocument.getLineNumber(element.getTextOffset) + 1).toOption + + override def column(element: PsiElement): Option[Int] = { + for { + lineNumber <- line(element) + lineOffset <- Try( + element.getContainingFile.getViewProvider.getDocument.getLineStartOffset(lineNumber - 1) + ).toOption + } yield element.getTextOffset - lineOffset + } + override def lineEnd(element: PsiElement): Option[Int] = { val lastElement = element match { case namedFn: KtNamedFunction => @@ -135,7 +189,7 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid protected def getName(node: NewImport): String = { val isWildcard = node.isWildcard.getOrElse(false) - if (isWildcard) Constants.wildcardImportName + if (isWildcard) Constants.WildcardImportName else node.importedEntity.getOrElse("") } @@ -204,7 +258,7 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid argIdxMaybe: Option[Int], argNameMaybe: Option[String] = None, annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { + ): Seq[Ast] = { expr match { case typedExpr: KtAnnotatedExpression => astsForExpression( @@ -282,7 +336,7 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid } } - private def astForFile(fileWithMeta: KtFileWithMeta)(implicit typeInfoProvider: TypeInfoProvider): Ast = { + private def astForFile(fileWithMeta: KtFileWithMeta): Ast = { val ktFile = fileWithMeta.f val importDirectives = ktFile.getImportList.getImports.asScala @@ -295,13 +349,13 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid val packageName = ktFile.getPackageFqName.toString val node = - if (packageName == Constants.root) + if (packageName == Constants.Root) { NodeBuilders.newNamespaceBlockNode( NamespaceTraversal.globalNamespaceName, NamespaceTraversal.globalNamespaceName, relativizedPath ) - else { + } else { val name = packageName.split("\\.").lastOption.getOrElse("") NodeBuilders.newNamespaceBlockNode(name, packageName, relativizedPath) } @@ -318,8 +372,8 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid methodAstParentStack.push(fakeGlobalMethod) scope.pushNewScope(fakeGlobalMethod) - val blockNode_ = blockNode(ktFile, "", registerType(TypeConstants.any)) - val methodReturn = newMethodReturnNode(TypeConstants.any, None, None, None) + val blockNode_ = blockNode(ktFile, "", registerType(TypeConstants.Any)) + val methodReturn = newMethodReturnNode(TypeConstants.Any, None, None, None) val declarationsAsts = ktFile.getDeclarations.asScala.flatMap(astsForDeclaration) val fileNode = NewFile().name(fileWithMeta.relativizedPath) @@ -338,16 +392,14 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid Ast(fileNode).withChildren(namespaceBlockAst :: namespaceBlocksForImports) } - def astsForDeclaration(decl: KtDeclaration)(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { + def astsForDeclaration(decl: KtDeclaration): Seq[Ast] = { debugScope.push(decl) val result = try { decl match { - case c: KtClass => astsForClassOrObject(c) - case o: KtObjectDeclaration => astsForClassOrObject(o) - case n: KtNamedFunction => - val isExtensionFn = typeInfoProvider.isExtensionFn(n) - astsForMethod(n, isExtensionFn) + case c: KtClass => astsForClassOrObject(c) + case o: KtObjectDeclaration => astsForClassOrObject(o) + case n: KtNamedFunction => astsForMethod(n) case t: KtTypeAlias => Seq(astForTypeAlias(t)) case s: KtSecondaryConstructor => Seq(astForUnknown(s, None, None)) case p: KtProperty => astsForProperty(p) @@ -359,12 +411,9 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid } } catch { case exception: Exception => - val declText = decl.getText - val stringWriter = new StringWriter() - val printWriter = new PrintWriter(stringWriter) - exception.printStackTrace(printWriter) logger.warn( - s"Caught exception while processing decl in this file `$relativizedPath`:\n$declText\n${stringWriter.toString}" + s"Caught exception while processing decl in this file `$relativizedPath`:\n${decl.getText}", + exception ) Seq() } @@ -377,45 +426,49 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid argIdx: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val node = unknownNode(expr, Option(expr).map(_.getText).getOrElse(Constants.codePropUndefinedValue)) + ): Ast = { + val node = unknownNode(expr, Option(expr).map(_.getText).getOrElse(Constants.CodePropUndefinedValue)) Ast(withArgumentIndex(node, argIdx).argumentName(argNameMaybe)) .withChildren(annotations.map(astForAnnotationEntry)) } protected def assignmentAstForDestructuringEntry( entry: KtDestructuringDeclarationEntry, - componentNReceiverName: String, - componentNTypeFullName: String, + rhsBaseAst: Ast, componentIdx: Integer - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val entryTypeFullName = registerType(typeInfoProvider.typeFullName(entry, TypeConstants.any)) + ): Ast = { + val entryTypeFullName = registerType( + bindingUtils + .getVariableDesc(entry) + .flatMap(desc => nameRenderer.typeFullName(desc.getType)) + .getOrElse(TypeConstants.Any) + ) val assignmentLHSNode = identifierNode(entry, entry.getText, entry.getText, entryTypeFullName) val assignmentLHSAst = astWithRefEdgeMaybe(assignmentLHSNode.name, assignmentLHSNode) - val componentNIdentifierNode = - identifierNode(entry, componentNReceiverName, componentNReceiverName, componentNTypeFullName) - .argumentIndex(0) - - val fallbackSignature = s"${Defines.UnresolvedNamespace}()" - val fallbackFullName = - s"${Defines.UnresolvedNamespace}${Constants.componentNPrefix}$componentIdx:$fallbackSignature" - val (fullName, signature) = - typeInfoProvider.fullNameWithSignature(entry, (fallbackFullName, fallbackSignature)) - val componentNCallCode = s"$componentNReceiverName.${Constants.componentNPrefix}$componentIdx()" + val desc = bindingUtils.getCalledFunctionDesc(entry) + val descFullName = desc + .flatMap(nameRenderer.descFullName) + .getOrElse(s"${Defines.UnresolvedNamespace}${Constants.ComponentNPrefix}$componentIdx") + val signature = desc + .flatMap(nameRenderer.funcDescSignature) + .getOrElse(s"${Defines.UnresolvedSignature}()") + val fullName = nameRenderer.combineFunctionFullName(descFullName, signature) + + val componentNCallCode = + s"${rhsBaseAst.root.get.asInstanceOf[ExpressionNew].code}.${Constants.ComponentNPrefix}$componentIdx()" val componentNCallNode = callNode( entry, componentNCallCode, - s"${Constants.componentNPrefix}$componentIdx", + s"${Constants.ComponentNPrefix}$componentIdx", fullName, DispatchTypes.DYNAMIC_DISPATCH, Some(signature), Some(entryTypeFullName) ) - val componentNIdentifierAst = astWithRefEdgeMaybe(componentNIdentifierNode.name, componentNIdentifierNode) val componentNAst = - callAst(componentNCallNode, Seq(), Option(componentNIdentifierAst)) + callAst(componentNCallNode, Seq(), Option(rhsBaseAst)) val assignmentCallNode = NodeBuilders.newOperatorCallNode( Operators.assignment, @@ -427,15 +480,12 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid callAst(assignmentCallNode, List(assignmentLHSAst, componentNAst)) } - protected def astDerivedFullNameWithSignature(expr: KtQualifiedExpression, argAsts: List[Ast])(implicit - typeInfoProvider: TypeInfoProvider - ): (String, String) = { + protected def astDerivedFullNameWithSignature(expr: KtQualifiedExpression, argAsts: List[Ast]): (String, String) = { val astDerivedMethodFullName = expr.getSelectorExpression match { case expression: KtCallExpression => val receiverPlaceholderType = Defines.UnresolvedNamespace - val shortName = expr.getSelectorExpression.getFirstChild.getText - val args = expression.getValueArguments - s"$receiverPlaceholderType.$shortName:${typeInfoProvider.anySignature(args.asScala.toList)}" + val shortName = expression.getFirstChild.getText + s"$receiverPlaceholderType.$shortName" case _: KtNameReferenceExpression => Operators.fieldAccess case _ => @@ -443,23 +493,20 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid "" } - val astDerivedSignature = typeInfoProvider.anySignature(argAsts) + val astDerivedSignature = s"${Defines.UnresolvedSignature}(${argAsts.size})" (astDerivedMethodFullName, astDerivedSignature) } - protected def selectorExpressionArgAsts( - expr: KtQualifiedExpression - )(implicit typeInfoProvider: TypeInfoProvider): List[Ast] = { - expr.getSelectorExpression match { - case typedExpr: KtCallExpression => - withIndex(typedExpr.getValueArguments.asScala.toSeq) { case (arg, idx) => - astsForExpression(arg.getArgumentExpression, Some(idx)) - }.flatten.toList - case typedExpr: KtNameReferenceExpression => - val node = fieldIdentifierNode(typedExpr, typedExpr.getText, typedExpr.getText).argumentIndex(2) - List(Ast(node)) - case _ => List() - } + protected def astsForKtCallExpressionArguments(callExpr: KtCallExpression, startIndex: Int = 1): List[Ast] = { + withIndex(callExpr.getValueArguments.asScala.toSeq) { case (arg, idx) => + val argumentNameMaybe = Option(arg.getArgumentName).map(_.getText) + astsForExpression(arg.getArgumentExpression, Some(startIndex + idx - 1), argumentNameMaybe) + }.flatten.toList + } + + protected def selectorExpressionArgAsts(expr: KtQualifiedExpression, startIndex: Int = 1): List[Ast] = { + val callExpr = expr.getSelectorExpression.asInstanceOf[KtCallExpression] + astsForKtCallExpressionArguments(callExpr, startIndex) } protected def modifierTypeForVisibility(visibility: DescriptorVisibility): String = { @@ -473,4 +520,29 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid ModifierTypes.INTERNAL else "UNKNOWN" } + + protected def addToLambdaBindingInfoQueue( + bindingNode: NewBinding, + typeDecl: NewTypeDecl, + methodNode: NewMethod + ): Unit = { + lambdaBindingInfoQueue.prepend( + BindingInfo(bindingNode, Seq((typeDecl, bindingNode, EdgeTypes.BINDS), (bindingNode, methodNode, EdgeTypes.REF))) + ) + } + + protected def exprTypeFullName(expr: KtExpression): Option[String] = { + bindingUtils.getExprType(expr).flatMap(nameRenderer.typeFullName) + } + + protected def fullNameByImportPath(typeRef: KtTypeReference, file: KtFile): Option[String] = { + if (typeRef == null) { None } + else { + val typeRefText = typeRef.getText.stripSuffix("?") + file.getImportList.getImports.asScala.collectFirst { + case directive if directive.getImportedName != null && directive.getImportedName.toString == typeRefText => + directive.getImportPath.getPathStr + } + } + } } diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForDeclarationsCreator.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForDeclarationsCreator.scala index b9474e5b8415..467894174148 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForDeclarationsCreator.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForDeclarationsCreator.scala @@ -3,9 +3,7 @@ package io.joern.kotlin2cpg.ast import io.joern.kotlin2cpg.Constants import io.joern.kotlin2cpg.psi.PsiUtils import io.joern.kotlin2cpg.psi.PsiUtils.nonUnderscoreDestructuringEntries -import io.joern.kotlin2cpg.types.AnonymousObjectContext import io.joern.kotlin2cpg.types.TypeConstants -import io.joern.kotlin2cpg.types.TypeInfoProvider import io.joern.x2cpg.Ast import io.joern.x2cpg.datastructures.Stack.* import io.joern.x2cpg.Defines @@ -27,55 +25,83 @@ import io.shiftleft.semanticcpg.language.* import org.jetbrains.kotlin.descriptors.Modality import org.jetbrains.kotlin.psi.* +import scala.collection.mutable import scala.jdk.CollectionConverters.* import scala.util.Random trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - private def isAbstract(ktClass: KtClassOrObject)(implicit typeInfoProvider: TypeInfoProvider): Boolean = { - typeInfoProvider.modality(ktClass).contains(Modality.ABSTRACT) - } + import AstCreator.AnonymousObjectContext + import AstCreator.BindingInfo def astsForClassOrObject( ktClass: KtClassOrObject, ctx: Option[AnonymousObjectContext] = None, annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { + ): Seq[Ast] = { val className = ctx match { case Some(_) => "anonymous_obj" case None => ktClass.getName } - val explicitFullName = { - val fqName = ktClass.getContainingKtFile.getPackageFqName.toString - s"$fqName.$className" + val classDesc = bindingUtils.getClassDesc(ktClass) + + val classFullName = + nameRenderer.descFullName(classDesc).getOrElse { + val fqName = ktClass.getContainingKtFile.getPackageFqName.toString + s"$fqName.$className" + } + registerType(classFullName) + + val baseTypeFullNames = + ktClass.getSuperTypeListEntries.asScala + .flatMap { superTypeEntry => + val typeRef = superTypeEntry.getTypeReference + val superType = bindingUtils + .getTypeRefType(typeRef) + .flatMap(nameRenderer.typeFullName) + + superType.orElse { + fullNameByImportPath(typeRef, ktClass.getContainingKtFile) + } + } + .to(mutable.ArrayBuffer) + + if (baseTypeFullNames.isEmpty) { + baseTypeFullNames.append(TypeConstants.JavaLangObject) } - val classFullName = registerType(typeInfoProvider.fullName(ktClass, explicitFullName, ctx)) - val explicitBaseTypeFullNames = ktClass.getSuperTypeListEntries.asScala - .map(_.getTypeAsUserType) - .collect { case t if t != null => t.getText } - .map { typ => typeInfoProvider.typeFromImports(typ, ktClass.getContainingKtFile).getOrElse(typ) } - .toList - - val baseTypeFullNames = typeInfoProvider.inheritanceTypes(ktClass, explicitBaseTypeFullNames) + baseTypeFullNames.foreach(registerType) - val outBaseTypeFullNames = Option(baseTypeFullNames).filter(_.nonEmpty).getOrElse(Seq(TypeConstants.javaLangObject)) - val typeDecl = typeDeclNode(ktClass, className, classFullName, relativizedPath, outBaseTypeFullNames, None) + val typeDecl = typeDeclNode(ktClass, className, classFullName, relativizedPath, baseTypeFullNames.toSeq, None) scope.pushNewScope(typeDecl) methodAstParentStack.push(typeDecl) val primaryCtor = ktClass.getPrimaryConstructor val constructorParams = ktClass.getPrimaryConstructorParameters.asScala.toList - val defaultSignature = Option(primaryCtor) - .map { _ => typeInfoProvider.anySignature(constructorParams) } - .getOrElse(s"${TypeConstants.void}()") - val defaultFullName = s"$classFullName.${TypeConstants.initPrefix}:$defaultSignature" - val (fullName, signature) = typeInfoProvider.fullNameWithSignature(primaryCtor, (defaultFullName, defaultSignature)) - val primaryCtorMethodNode = methodNode(primaryCtor, TypeConstants.initPrefix, fullName, signature, relativizedPath) + + val (fullName, signature) = + if (primaryCtor != null) { + val constructorDesc = bindingUtils.getConstructorDesc(primaryCtor) + val descFullName = nameRenderer + .descFullName(constructorDesc) + .getOrElse(s"$classFullName.${Defines.ConstructorMethodName}") + val signature = nameRenderer + .funcDescSignature(constructorDesc) + .getOrElse(s"${Defines.UnresolvedSignature}(${primaryCtor.getValueParameters.size()})") + val fullName = nameRenderer.combineFunctionFullName(descFullName, signature) + (fullName, signature) + } else { + val descFullName = s"$classFullName.${Defines.ConstructorMethodName}" + val signature = s"${TypeConstants.Void}()" + val fullName = nameRenderer.combineFunctionFullName(descFullName, signature) + (fullName, signature) + } + val primaryCtorMethodNode = + methodNode(primaryCtor, Defines.ConstructorMethodName, fullName, signature, relativizedPath) val ctorThisParam = NodeBuilders.newThisParameterNode(typeFullName = classFullName, dynamicTypeHintFullName = Seq(classFullName)) - scope.addToScope(Constants.this_, ctorThisParam) + scope.addToScope(Constants.ThisName, ctorThisParam) val constructorParamsAsts = Seq(Ast(ctorThisParam)) ++ withIndex(constructorParams) { (p, idx) => astForParameter(p, idx) @@ -97,12 +123,12 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { if (initializerAsts.size == 1) initializerAsts.head else Ast(unknownNode(decl, "")) - val thisIdentifier = newIdentifierNode(Constants.this_, classFullName, Seq(classFullName)) - val thisAst = astWithRefEdgeMaybe(Constants.this_, thisIdentifier) + val thisIdentifier = newIdentifierNode(Constants.ThisName, classFullName, Seq(classFullName)) + val thisAst = astWithRefEdgeMaybe(Constants.ThisName, thisIdentifier) val fieldIdentifier = fieldIdentifierNode(decl, decl.getName, decl.getName) val fieldAccessCall = NodeBuilders - .newOperatorCallNode(Operators.fieldAccess, s"${Constants.this_}.${fieldIdentifier.canonicalName}", None) + .newOperatorCallNode(Operators.fieldAccess, s"${Constants.ThisName}.${fieldIdentifier.canonicalName}", None) val fieldAccessCallAst = callAst(fieldAccessCall, List(thisAst, Ast(fieldIdentifier))) val assignmentNode = NodeBuilders @@ -114,7 +140,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val anonymousInitAsts = anonymousInitExpressions.flatMap(astsForExpression(_, None)) val constructorMethodReturn = newMethodReturnNode( - TypeConstants.void, + TypeConstants.Void, None, line(ktClass.getPrimaryConstructor), column(ktClass.getPrimaryConstructor) @@ -123,7 +149,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { primaryCtorMethodNode, constructorParamsAsts, blockAst( - blockNode(ktClass, "", TypeConstants.void), + blockNode(ktClass, "", TypeConstants.Void), memberSetCalls ++ memberInitializerSetCalls ++ anonymousInitAsts ), constructorMethodReturn, @@ -136,8 +162,10 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val membersFromPrimaryCtorAsts = ktClass.getPrimaryConstructorParameters.asScala.toList.collect { case param if param.hasValOrVar => - val typeFullName = registerType(typeInfoProvider.parameterType(param, TypeConstants.any)) - val memberNode_ = memberNode(param, param.getName, param.getName, typeFullName) + val typeFullName = registerType( + nameRenderer.typeFullName(bindingUtils.getVariableDesc(param).get.getType).getOrElse(TypeConstants.Any) + ) + val memberNode_ = memberNode(param, param.getName, param.getName, typeFullName) scope.addToScope(param.getName, memberNode_) Ast(memberNode_) } @@ -145,7 +173,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val primaryCtorCall = callNode( ktClass.getPrimaryConstructor, - TypeConstants.initPrefix, + Defines.ConstructorMethodName, primaryCtorMethodNode.name, primaryCtorMethodNode.fullName, DispatchTypes.STATIC_DISPATCH, @@ -169,14 +197,14 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val innerTypeDeclAsts = classDeclarations.toSeq .collectAll[KtClassOrObject] - .filterNot(typeInfoProvider.isCompanionObject) + .filterNot(desc => bindingUtils.getClassDesc(desc).isCompanionObject) .flatMap(astsForDeclaration(_)) val classFunctions = Option(ktClass.getBody) .map(_.getFunctions.asScala.collect { case f: KtNamedFunction => f }) .getOrElse(List()) val methodAsts = classFunctions.toSeq.flatMap { classFn => - astsForMethod(classFn, needsThisParameter = true, withVirtualModifier = true) + astsForMethod(classFn, withVirtualModifier = true) } val bindingsInfo = methodAsts.flatMap(_.root.collectAll[NewMethod]).map { _methodNode => val node = newBindingNode(_methodNode.name, _methodNode.signature, _methodNode.fullName) @@ -185,7 +213,11 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val annotationAsts = ktClass.getAnnotationEntries.asScala.map(astForAnnotationEntry).toSeq - val modifiers = if (isAbstract(ktClass)) List(Ast(NodeBuilders.newModifierNode(ModifierTypes.ABSTRACT))) else Nil + val modifiers = if (classDesc.getModality == Modality.ABSTRACT) { + List(Ast(NodeBuilders.newModifierNode(ModifierTypes.ABSTRACT))) + } else { + Nil + } val children = methodAsts ++ List(constructorAst) ++ membersFromPrimaryCtorAsts ++ secondaryConstructorAsts ++ _componentNMethodAsts.toList ++ memberAsts ++ annotationAsts ++ modifiers @@ -193,17 +225,18 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { (List(ctorBindingInfo) ++ bindingsInfo ++ componentNBindingsInfo).foreach(bindingInfoQueue.prepend) - val finalAst = if (typeInfoProvider.isCompanionObject(ktClass)) { + val finalAst = if (classDesc.isCompanionObject) { val companionMemberTypeFullName = ktClass.getParent.getParent match { - case c: KtClassOrObject => typeInfoProvider.typeFullName(c, TypeConstants.any) - case _ => TypeConstants.any + case c: KtClassOrObject => + nameRenderer.descFullName(bindingUtils.getClassDesc(c)).getOrElse(TypeConstants.Any) + case _ => TypeConstants.Any } registerType(companionMemberTypeFullName) val companionObjectMember = memberNode( ktClass, - Constants.companionObjectMemberName, - Constants.companionObjectMemberName, + Constants.CompanionObjectMemberName, + Constants.CompanionObjectMemberName, companionMemberTypeFullName ) ast.withChild(Ast(companionObjectMember)) @@ -217,19 +250,22 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { Seq(finalAst.withChildren(annotations.map(astForAnnotationEntry))) ++ companionObjectAsts ++ innerTypeDeclAsts } - private def memberSetCallAst(param: KtParameter, classFullName: String)(implicit - typeInfoProvider: TypeInfoProvider - ): Ast = { - val typeFullName = registerType(typeInfoProvider.typeFullName(param, TypeConstants.any)) + private def memberSetCallAst(param: KtParameter, classFullName: String): Ast = { + val typeFullName = registerType( + bindingUtils + .getVariableDesc(param) + .flatMap(desc => nameRenderer.typeFullName(desc.getType)) + .getOrElse(TypeConstants.Any) + ) val paramName = param.getName val paramIdentifier = identifierNode(param, paramName, paramName, typeFullName) val paramIdentifierAst = astWithRefEdgeMaybe(paramName, paramIdentifier) - val thisIdentifier = newIdentifierNode(Constants.this_, classFullName, Seq(classFullName)) - val thisAst = astWithRefEdgeMaybe(Constants.this_, thisIdentifier) + val thisIdentifier = newIdentifierNode(Constants.ThisName, classFullName, Seq(classFullName)) + val thisAst = astWithRefEdgeMaybe(Constants.ThisName, thisIdentifier) val fieldIdentifier = fieldIdentifierNode(param, paramName, paramName) val fieldAccessCall = - NodeBuilders.newOperatorCallNode(Operators.fieldAccess, s"${Constants.this_}.$paramName", Option(typeFullName)) + NodeBuilders.newOperatorCallNode(Operators.fieldAccess, s"${Constants.ThisName}.$paramName", Option(typeFullName)) val fieldAccessCallAst = callAst(fieldAccessCall, List(thisAst, Ast(fieldIdentifier))) val assignmentNode = @@ -237,9 +273,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { callAst(assignmentNode, List(fieldAccessCallAst, paramIdentifierAst)) } - private def astsForDestructuringDeclarationWithRHS( - expr: KtDestructuringDeclaration - )(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { + private def astsForDestructuringDeclarationWithRHS(expr: KtDestructuringDeclaration): Seq[Ast] = { val typedInit = Option(expr.getInitializer).collect { case c: KtCallExpression => c case dqe: KtDotQualifiedExpression => dqe @@ -255,21 +289,15 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { return Seq() } val rhsCall = typedInit.get - val callRhsTypeFullName = registerType(typeInfoProvider.expressionType(rhsCall, TypeConstants.any)) - - val destructuringEntries = nonUnderscoreDestructuringEntries(expr) - val localsForEntries = destructuringEntries.map { entry => - val typeFullName = registerType(typeInfoProvider.typeFullName(entry, TypeConstants.any)) - val node = localNode(entry, entry.getName, entry.getName, typeFullName) - scope.addToScope(node.name, node) - Ast(node) - } + val callRhsTypeFullName = registerType(exprTypeFullName(rhsCall).getOrElse(TypeConstants.Any)) + + val localsForEntries = localsForDestructuringEntries(expr) val isCtor = expr.getInitializer match { case _: KtCallExpression => typeInfoProvider.isConstructorCall(rhsCall).getOrElse(false) case _ => false } - val tmpName = s"${Constants.tmpLocalPrefix}${tmpKeyPool.next}" + val tmpName = s"${Constants.TmpLocalPrefix}${tmpKeyPool.next}" val localForTmpNode = localNode(expr, tmpName, tmpName, callRhsTypeFullName) scope.addToScope(localForTmpNode.name, localForTmpNode) val localForTmpAst = Ast(localForTmpNode) @@ -280,13 +308,13 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { if (isCtor) { val assignmentRhsNode = NodeBuilders.newOperatorCallNode( Operators.alloc, - Constants.alloc, + Constants.Alloc, Option(localForTmpNode.typeFullName), line(expr), column(expr) ) val assignmentNode = - NodeBuilders.newOperatorCallNode(Operators.assignment, s"$tmpName = ${Constants.alloc}", None) + NodeBuilders.newOperatorCallNode(Operators.assignment, s"$tmpName = ${Constants.Alloc}", None) callAst(assignmentNode, List(assignmentLhsAst, Ast(assignmentRhsNode))) } else { expr.getInitializer match { @@ -302,7 +330,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val assignmentNode = NodeBuilders.newOperatorCallNode(Operators.assignment, s"$tmpName = ${rhsCall.getText}", None) val assignmentRhsAst = - astsForExpression(rhsCall, None).headOption.getOrElse(Ast(unknownNode(rhsCall, Constants.empty))) + astsForExpression(rhsCall, None).headOption.getOrElse(Ast(unknownNode(rhsCall, Constants.Empty))) callAst(assignmentNode, List(assignmentLhsAst, assignmentRhsAst)) } } @@ -311,28 +339,35 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val initReceiverNode = identifierNode(expr, tmpName, tmpName, localForTmpNode.typeFullName).argumentIndex(0) val initReceiverAst = Ast(initReceiverNode).withRefEdge(initReceiverNode, localForTmpNode) - - val argAsts = withIndex(call.getValueArguments.asScala.toSeq) { case (arg, idx) => - astsForExpression(arg.getArgumentExpression, Some(idx)) - }.flatten - - val (fullName, signature) = typeInfoProvider.fullNameWithSignature(call, (TypeConstants.any, TypeConstants.any)) - registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + val argAsts = astsForKtCallExpressionArguments(call) + val (fullName, signature) = + calleeFullnameAndSignature( + getCalleeExpr(rhsCall), + Defines.UnresolvedNamespace, + s"${Defines.UnresolvedSignature}(${call.getValueArguments.size()})" + ) + registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val initCallNode = callNode( expr, - Constants.init, - Constants.init, + Defines.ConstructorMethodName, + Defines.ConstructorMethodName, fullName, DispatchTypes.STATIC_DISPATCH, Some(signature), - Some(TypeConstants.void) + Some(TypeConstants.Void) ) Seq(callAst(initCallNode, argAsts, Some(initReceiverAst), None)) case _ => Seq() } - val assignmentsForEntries = destructuringEntries.zipWithIndex.map { case (entry, idx) => - assignmentAstForDestructuringEntry(entry, localForTmpNode.name, localForTmpNode.typeFullName, idx + 1) + val assignmentsForEntries = nonUnderscoreDestructuringEntries(expr).zipWithIndex.map { case (entry, idx) => + val rhsBaseAst = + astWithRefEdgeMaybe( + localForTmpNode.name, + identifierNode(entry, localForTmpNode.name, localForTmpNode.name, localForTmpNode.typeFullName) + .argumentIndex(0) + ) + assignmentAstForDestructuringEntry(entry, rhsBaseAst, idx + 1) } localsForEntries ++ Seq(localForTmpAst) ++ Seq(tmpAssignmentAst) ++ tmpAssignmentPrologue ++ assignmentsForEntries @@ -349,33 +384,41 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { | -> CALL two = person.component2() |__________________________________ */ - private def astsForDestructuringDeclarationWithVarRHS( - expr: KtDestructuringDeclaration - )(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { + private def astsForDestructuringDeclarationWithVarRHS(expr: KtDestructuringDeclaration): Seq[Ast] = { val typedInit = Option(expr.getInitializer).collect { case e: KtNameReferenceExpression => e } if (typedInit.isEmpty) { logger.warn(s"Unhandled case for destructuring declaration: `${expr.getText}` in this file `$relativizedPath`.") return Seq() } - val destructuringRHS = typedInit.get - val initTypeFullName = registerType(typeInfoProvider.typeFullName(typedInit.get, TypeConstants.any)) val assignmentsForEntries = nonUnderscoreDestructuringEntries(expr).zipWithIndex.map { case (entry, idx) => - assignmentAstForDestructuringEntry(entry, destructuringRHS.getText, initTypeFullName, idx + 1) + val rhsBaseAst = astForNameReference(typedInit.get, Some(1), None) + assignmentAstForDestructuringEntry(entry, rhsBaseAst, idx + 1) } - val localsForEntries = nonUnderscoreDestructuringEntries(expr).map { entry => - val typeFullName = registerType(typeInfoProvider.typeFullName(entry, TypeConstants.any)) - val node = localNode(entry, entry.getName, entry.getName, typeFullName) - scope.addToScope(node.name, node) - Ast(node) - } + val localsForEntries = localsForDestructuringEntries(expr) localsForEntries ++ assignmentsForEntries } - def astsForDestructuringDeclaration( - expr: KtDestructuringDeclaration - )(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { + def localsForDestructuringEntries(destructuring: KtDestructuringDeclaration): Seq[Ast] = { + destructuring.getEntries.asScala + .filterNot(_.getText == Constants.UnusedDestructuringEntryText) + .map { entry => + val entryTypeFullName = registerType( + bindingUtils + .getVariableDesc(entry) + .flatMap(desc => nameRenderer.typeFullName(desc.getType)) + .getOrElse(TypeConstants.Any) + ) + val entryName = entry.getText + val node = localNode(entry, entryName, entryName, entryTypeFullName) + scope.addToScope(entryName, node) + Ast(node) + } + .toSeq + } + + def astsForDestructuringDeclaration(expr: KtDestructuringDeclaration): Seq[Ast] = { val hasNonRefExprRHS = expr.getInitializer match { case _: KtNameReferenceExpression => false case _: KtExpression => true @@ -385,31 +428,34 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { else astsForDestructuringDeclarationWithVarRHS(expr) } - private def componentNMethodAsts(typeDecl: NewTypeDecl, parameters: Seq[KtParameter])(implicit - typeInfoProvider: TypeInfoProvider - ): Seq[Ast] = { + private def componentNMethodAsts(typeDecl: NewTypeDecl, parameters: Seq[KtParameter]): Seq[Ast] = { parameters.zipWithIndex.map { case (valueParam, idx) => - val typeFullName = registerType(typeInfoProvider.typeFullName(valueParam, TypeConstants.any)) + val typeFullName = registerType( + bindingUtils + .getVariableDesc(valueParam) + .flatMap(desc => nameRenderer.typeFullName(desc.getType)) + .getOrElse(TypeConstants.Any) + ) val thisParam = NodeBuilders.newThisParameterNode(typeFullName = typeDecl.fullName, dynamicTypeHintFullName = Seq()) - val thisIdentifier = newIdentifierNode(Constants.this_, typeDecl.fullName, Seq(typeDecl.fullName)) + val thisIdentifier = newIdentifierNode(Constants.ThisName, typeDecl.fullName, Seq(typeDecl.fullName)) val thisAst = Ast(thisIdentifier).withRefEdge(thisIdentifier, thisParam) val fieldIdentifier = fieldIdentifierNode(valueParam, valueParam.getName, valueParam.getName) val fieldAccessCall = NodeBuilders.newOperatorCallNode( Operators.fieldAccess, - s"${Constants.this_}.${valueParam.getName}", + s"${Constants.ThisName}.${valueParam.getName}", Option(typeFullName) ) val fieldAccessCallAst = callAst(fieldAccessCall, List(thisAst, Ast(fieldIdentifier))) val methodBlockAst = blockAst( blockNode(valueParam, fieldAccessCall.code, typeFullName), - List(returnAst(returnNode(valueParam, Constants.ret), List(fieldAccessCallAst))) + List(returnAst(returnNode(valueParam, Constants.RetCode), List(fieldAccessCallAst))) ) val componentIdx = idx + 1 - val componentName = s"${Constants.componentNPrefix}$componentIdx" + val componentName = s"${Constants.ComponentNPrefix}$componentIdx" val signature = s"$typeFullName()" val fullName = s"${typeDecl.fullName}.$componentName:$signature" methodAst( @@ -421,22 +467,30 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { } } - private def secondaryCtorAsts(ctors: Seq[KtSecondaryConstructor], classFullName: String, primaryCtorCall: NewCall)( - implicit typeInfoProvider: TypeInfoProvider + private def secondaryCtorAsts( + ctors: Seq[KtSecondaryConstructor], + classFullName: String, + primaryCtorCall: NewCall ): Seq[Ast] = { ctors.map { ctor => - val primaryCtorCallAst = List(Ast(primaryCtorCall.copy)) - val constructorParams = ctor.getValueParameters.asScala.toList - val defaultSignature = typeInfoProvider.anySignature(constructorParams) - val defaultFullName = s"$classFullName.${TypeConstants.initPrefix}:$defaultSignature" - val (fullName, signature) = typeInfoProvider.fullNameWithSignature(ctor, (defaultFullName, defaultSignature)) + val primaryCtorCallAst = List(Ast(primaryCtorCall.copy)) + val constructorParams = ctor.getValueParameters.asScala.toList + + val constructorDesc = bindingUtils.getConstructorDesc(ctor) + val descFullName = nameRenderer + .descFullName(constructorDesc) + .getOrElse(s"$classFullName.${Defines.ConstructorMethodName}") + val signature = nameRenderer + .funcDescSignature(constructorDesc) + .getOrElse(s"${Defines.UnresolvedSignature}(${ctor.getValueParameters.size()})") + val fullName = nameRenderer.combineFunctionFullName(descFullName, signature) val secondaryCtorMethodNode = - methodNode(ctor, Constants.init, fullName, signature, relativizedPath) + methodNode(ctor, Defines.ConstructorMethodName, fullName, signature, relativizedPath) scope.pushNewScope(secondaryCtorMethodNode) val ctorThisParam = NodeBuilders.newThisParameterNode(typeFullName = classFullName, dynamicTypeHintFullName = Seq(classFullName)) - scope.addToScope(Constants.this_, ctorThisParam) + scope.addToScope(Constants.ThisName, ctorThisParam) val constructorParamsAsts = Seq(Ast(ctorThisParam)) ++ withIndex(constructorParams) { (p, idx) => astForParameter(p, idx) } @@ -446,19 +500,19 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { case b: KtBlockExpression => astsForBlock(b, None, None, preStatements = Option(primaryCtorCallAst)) case null => - val node = NewBlock().code(Constants.empty).typeFullName(TypeConstants.any) + val node = NewBlock().code(Constants.Empty).typeFullName(TypeConstants.Any) Seq(blockAst(node, primaryCtorCallAst)) } scope.popScope() val ctorMethodReturnNode = - newMethodReturnNode(TypeConstants.void, None, line(ctor), column(ctor)) + newMethodReturnNode(TypeConstants.Void, None, line(ctor), column(ctor)) // TODO: see if necessary to take the other asts for the ctorMethodBlock methodAst( secondaryCtorMethodNode, constructorParamsAsts, - ctorMethodBlockAsts.headOption.getOrElse(Ast(unknownNode(ctor.getBodyExpression, Constants.empty))), + ctorMethodBlockAsts.headOption.getOrElse(Ast(unknownNode(ctor.getBodyExpression, Constants.Empty))), ctorMethodReturnNode, Seq(newModifierNode(ModifierTypes.CONSTRUCTOR)) ) @@ -470,7 +524,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { argIdxMaybe: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val parentFn = KtPsiUtil.getTopmostParentOfTypes(expr, classOf[KtNamedFunction]) val ctx = Option(parentFn) @@ -483,7 +537,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val tmpName = s"tmp_obj_$idx" val typeDeclAsts = astsForClassOrObject(expr.getObjectDeclaration, Some(ctx)) - val typeDeclAst = typeDeclAsts.headOption.getOrElse(Ast(unknownNode(expr.getObjectDeclaration, Constants.empty))) + val typeDeclAst = typeDeclAsts.headOption.getOrElse(Ast(unknownNode(expr.getObjectDeclaration, Constants.Empty))) val typeDeclFullName = typeDeclAst.root.get.asInstanceOf[NewTypeDecl].fullName val localForTmp = localNode(expr, tmpName, tmpName, typeDeclFullName) @@ -503,17 +557,17 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { column(expr) ) val assignmentCallAst = callAst(assignmentNode, List(identifierAst) ++ List(rhsAst)) - val initSignature = s"${TypeConstants.void}()" - val initFullName = s"$typeDeclFullName.${TypeConstants.initPrefix}:$initSignature" + val initSignature = s"${TypeConstants.Void}()" + val initFullName = s"$typeDeclFullName.${Defines.ConstructorMethodName}:$initSignature" val initCallNode = callNode( expr, - Constants.init, - Constants.init, + Defines.ConstructorMethodName, + Defines.ConstructorMethodName, initFullName, DispatchTypes.STATIC_DISPATCH, Some(initSignature), - Some(TypeConstants.void) + Some(TypeConstants.Void) ) val initReceiverNode = identifierNode(expr, identifier.name, identifier.name, identifier.typeFullName) @@ -525,16 +579,14 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val refTmpAst = astWithRefEdgeMaybe(refTmpNode.name, refTmpNode) val blockNode_ = - withArgumentIndex(blockNode(expr, expr.getText, TypeConstants.any), argIdxMaybe) + withArgumentIndex(blockNode(expr, expr.getText, TypeConstants.Any), argIdxMaybe) .argumentName(argNameMaybe) blockAst(blockNode_, Seq(typeDeclAst, localAst, assignmentCallAst, initAst, refTmpAst).toList) .withChildren(annotations.map(astForAnnotationEntry)) } - def astsForProperty(expr: KtProperty, annotations: Seq[KtAnnotationEntry] = Seq())(implicit - typeInfoProvider: TypeInfoProvider - ): Seq[Ast] = { - val explicitTypeName = Option(expr.getTypeReference).map(_.getText).getOrElse(TypeConstants.any) + def astsForProperty(expr: KtProperty, annotations: Seq[KtAnnotationEntry] = Seq()): Seq[Ast] = { + val explicitTypeName = Option(expr.getTypeReference).map(_.getText).getOrElse(TypeConstants.Any) val elem = expr.getIdentifyingElement val hasRHSCtorCall = @@ -556,14 +608,20 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { case _ => false } if (ctorCallExprMaybe.nonEmpty) { - val callExpr = ctorCallExprMaybe.get - val localTypeFullName = registerType(typeInfoProvider.propertyType(expr, explicitTypeName)) - val local = localNode(expr, expr.getName, expr.getName, localTypeFullName) + val callExpr = ctorCallExprMaybe.get + val localTypeFullName = + bindingUtils + .getVariableDesc(expr) + .flatMap(desc => nameRenderer.typeFullName(desc.getType)) + .orElse(fullNameByImportPath(expr.getTypeReference, expr.getContainingKtFile)) + .getOrElse(explicitTypeName) + registerType(localTypeFullName) + val local = localNode(expr, expr.getName, expr.getName, localTypeFullName) scope.addToScope(expr.getName, local) val localAst = Ast(local) val typeFullName = registerType( - typeInfoProvider.expressionType(expr.getDelegateExpressionOrInitializer, Defines.UnresolvedNamespace) + exprTypeFullName(expr.getDelegateExpressionOrInitializer).getOrElse(Defines.UnresolvedNamespace) ) val rhsAst = Ast(NodeBuilders.newOperatorCallNode(Operators.alloc, Operators.alloc, Option(typeFullName))) @@ -575,23 +633,23 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val assignmentCallAst = callAst(assignmentNode, List(identifierAst) ++ List(rhsAst)) val (fullName, signature) = - typeInfoProvider.fullNameWithSignature(callExpr, (TypeConstants.any, TypeConstants.any)) + calleeFullnameAndSignature( + getCalleeExpr(callExpr), + Defines.UnresolvedNamespace, + s"${Defines.UnresolvedSignature}(${callExpr.getValueArguments.size()})" + ) val initCallNode = callNode( callExpr, callExpr.getText, - Constants.init, + Defines.ConstructorMethodName, fullName, DispatchTypes.STATIC_DISPATCH, Some(signature), - Some(TypeConstants.void) + Some(TypeConstants.Void) ) val initReceiverNode = identifierNode(expr, identifier.name, identifier.name, identifier.typeFullName) val initReceiverAst = Ast(initReceiverNode).withRefEdge(initReceiverNode, local) - - val argAsts = withIndex(callExpr.getValueArguments.asScala.toSeq) { case (arg, idx) => - val argNameOpt = if (arg.isNamed) Option(arg.getArgumentName.getAsName.toString) else None - astsForExpression(arg.getArgumentExpression, Option(idx), argNameOpt) - }.flatten + val argAsts = astsForKtCallExpressionArguments(callExpr) val initAst = callAst(initCallNode, argAsts, Option(initReceiverAst)) @@ -608,7 +666,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val typeDeclAsts = astsForClassOrObject(typedExpr.getObjectDeclaration, Some(ctx)) val typeDeclAst = - typeDeclAsts.headOption.getOrElse(Ast(unknownNode(typedExpr.getObjectDeclaration, Constants.empty))) + typeDeclAsts.headOption.getOrElse(Ast(unknownNode(typedExpr.getObjectDeclaration, Constants.Empty))) val typeDeclFullName = typeDeclAst.root.get.asInstanceOf[NewTypeDecl].fullName val node = localNode(expr, expr.getName, expr.getName, typeDeclFullName) @@ -616,7 +674,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val localAst = Ast(node) val typeFullName = registerType( - typeInfoProvider.expressionType(expr.getDelegateExpressionOrInitializer, Defines.UnresolvedNamespace) + exprTypeFullName(expr.getDelegateExpressionOrInitializer).getOrElse(Defines.UnresolvedNamespace) ) val rhsAst = Ast(NodeBuilders.newOperatorCallNode(Operators.alloc, Operators.alloc, None)) @@ -626,16 +684,16 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { val assignmentNode = NodeBuilders.newOperatorCallNode(Operators.assignment, expr.getText, None, line(expr), column(expr)) val assignmentCallAst = callAst(assignmentNode, List(identifierAst) ++ List(rhsAst)) - val initSignature = s"${TypeConstants.void}()" - val initFullName = s"$typeFullName${TypeConstants.initPrefix}:$initSignature" + val initSignature = s"${TypeConstants.Void}()" + val initFullName = s"$typeFullName${Defines.ConstructorMethodName}:$initSignature" val initCallNode = callNode( expr, - Constants.init, - Constants.init, + Defines.ConstructorMethodName, + Defines.ConstructorMethodName, initFullName, DispatchTypes.STATIC_DISPATCH, Some(initSignature), - Some(TypeConstants.void) + Some(TypeConstants.Void) ) val initReceiverNode = identifierNode(expr, identifier.name, identifier.name, identifier.typeFullName) val initReceiverAst = Ast(initReceiverNode).withRefEdge(initReceiverNode, node) @@ -645,32 +703,46 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { .withChildren(annotations.map(astForAnnotationEntry)) Seq(typeDeclAst, localAst, assignmentCallAst, initAst) } else { - val typeFullName = registerType(typeInfoProvider.propertyType(expr, explicitTypeName)) - val node = localNode(expr, expr.getName, expr.getName, typeFullName) + val typeFullName = bindingUtils + .getVariableDesc(expr) + .flatMap(desc => nameRenderer.typeFullName(desc.getType)) + .orElse(fullNameByImportPath(expr.getTypeReference, expr.getContainingKtFile)) + .getOrElse(explicitTypeName) + registerType(typeFullName) + val node = localNode(expr, expr.getName, expr.getName, typeFullName) scope.addToScope(expr.getName, node) val localAst = Ast(node) - val rhsAsts = astsForExpression(expr.getDelegateExpressionOrInitializer, Some(2)) - val identifier = identifierNode(elem, elem.getText, elem.getText, typeFullName) - val identifierAst = astWithRefEdgeMaybe(identifier.name, identifier) - val assignmentNode = - NodeBuilders.newOperatorCallNode(Operators.assignment, expr.getText, None, line(expr), column(expr)) - val call = - callAst(assignmentNode, List(identifierAst) ++ rhsAsts) - .withChildren(annotations.map(astForAnnotationEntry)) - Seq(localAst, call) + if (expr.getDelegateExpressionOrInitializer != null) { + val rhsAsts = astsForExpression(expr.getDelegateExpressionOrInitializer, Some(2)) + val identifier = identifierNode(elem, elem.getText, elem.getText, typeFullName) + val identifierAst = astWithRefEdgeMaybe(identifier.name, identifier) + val assignmentNode = + NodeBuilders.newOperatorCallNode(Operators.assignment, expr.getText, None, line(expr), column(expr)) + val call = + callAst(assignmentNode, List(identifierAst) ++ rhsAsts) + .withChildren(annotations.map(astForAnnotationEntry)) + Seq(localAst, call) + } else { + Seq(localAst) + } } } - private def astForMember(decl: KtDeclaration)(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val name = Option(decl.getName).getOrElse(TypeConstants.any) + private def astForMember(decl: KtDeclaration): Ast = { + val name = Option(decl.getName).getOrElse(TypeConstants.Any) val explicitTypeName = decl.getOriginalElement match { case p: KtProperty if p.getTypeReference != null => p.getTypeReference.getText - case _ => TypeConstants.any + case _ => TypeConstants.Any } val typeFullName = decl match { - case typed: KtProperty => typeInfoProvider.propertyType(typed, explicitTypeName) - case _ => explicitTypeName + case typed: KtProperty => + bindingUtils + .getVariableDesc(typed) + .flatMap(desc => nameRenderer.typeFullName(desc.getType)) + .orElse(fullNameByImportPath(typed.getTypeReference, typed.getContainingKtFile)) + .getOrElse(explicitTypeName) + case _ => explicitTypeName } registerType(typeFullName) diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForExpressionsCreator.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForExpressionsCreator.scala index 6a6ee736ad36..bd8a5f4754aa 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForExpressionsCreator.scala @@ -3,7 +3,6 @@ package io.joern.kotlin2cpg.ast import io.joern.kotlin2cpg.Constants import io.joern.kotlin2cpg.types.CallKind import io.joern.kotlin2cpg.types.TypeConstants -import io.joern.kotlin2cpg.types.TypeInfoProvider import io.joern.x2cpg.Ast import io.joern.x2cpg.Defines import io.joern.x2cpg.ValidationMode @@ -11,6 +10,8 @@ import io.joern.x2cpg.utils.NodeBuilders import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.NewMethodRef +import org.jetbrains.kotlin.descriptors.DescriptorVisibilities +import org.jetbrains.kotlin.descriptors.FunctionDescriptor import org.jetbrains.kotlin.lexer.KtToken import org.jetbrains.kotlin.lexer.KtTokens import org.jetbrains.kotlin.psi.* @@ -25,7 +26,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { + ): Seq[Ast] = { val opRef = expr.getOperationReference // TODO: add the rest of the operators @@ -71,18 +72,37 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { logger.warn( s"Unhandled operator token type `${opRef.getOperationSignTokenType}` for expression `${expr.getText}` in this file `$relativizedPath`." ) - Some(Constants.unknownOperator) + Some(Constants.UnknownOperator) } val (fullName, signature) = - if (operatorOption.isDefined) (operatorOption.get, TypeConstants.any) + if (operatorOption.isDefined) (operatorOption.get, TypeConstants.Any) // TODO: fix the fallback METHOD_FULL_NAME and SIGNATURE here (should be a correct number of ANYs) - else typeInfoProvider.fullNameWithSignature(expr, (TypeConstants.any, TypeConstants.any)) + else { + val funcDesc = bindingUtils.getCalledFunctionDesc(expr.getOperationReference) + val descFullName = funcDesc + .orElse(getAmbiguousFuncDescIfFullNamesEqual(expr.getOperationReference)) + .flatMap(nameRenderer.descFullName) + .getOrElse(TypeConstants.Any) + val signature = funcDesc + .orElse(getAmbiguousFuncDescIfSignaturesEqual(expr.getOperationReference)) + .flatMap(nameRenderer.funcDescSignature) + .getOrElse(TypeConstants.Any) + val fullName = nameRenderer.combineFunctionFullName(descFullName, signature) + (fullName, signature) + } val finalSignature = // TODO: add test case for this situation - if (fullName.startsWith(Constants.operatorSuffix)) Constants.empty + if (fullName.startsWith(Constants.OperatorSuffix)) Constants.Empty else signature - val typeFullName = registerType(typeInfoProvider.typeFullName(expr, TypeConstants.any)) + + val typeFullName = registerType( + bindingUtils + .getCalledFunctionDesc(expr.getOperationReference) + .orElse(getAmbiguousFuncDescIfSignaturesEqual(expr.getOperationReference)) + .flatMap(funcDesc => nameRenderer.typeFullName(funcDesc.getOriginal.getReturnType)) + .getOrElse(TypeConstants.Any) + ) val name = if (operatorOption.isDefined) operatorOption.get else if (expr.getChildren.toList.sizeIs >= 2) expr.getChildren.toList(1).getText @@ -102,8 +122,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { callAst( withArgumentIndex(node, argIdx).argumentName(argNameMaybe), List( - lhsArgs.lastOption.getOrElse(Ast(unknownNode(expr.getLeft, Constants.empty))), - rhsArgs.lastOption.getOrElse(Ast(unknownNode(expr.getRight, Constants.empty))) + lhsArgs.lastOption.getOrElse(Ast(unknownNode(expr.getLeft, Constants.Empty))), + rhsArgs.lastOption.getOrElse(Ast(unknownNode(expr.getRight, Constants.Empty))) ) ) .withChildren(annotations.map(astForAnnotationEntry)) @@ -114,33 +134,38 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { expr: KtQualifiedExpression, argIdx: Option[Int], argNameMaybe: Option[String] - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val receiverAst = astsForExpression(expr.getReceiverExpression, Some(1)).headOption - .getOrElse(Ast(unknownNode(expr.getReceiverExpression, Constants.empty))) - val argAsts = selectorExpressionArgAsts(expr) - registerType(typeInfoProvider.containingDeclType(expr, TypeConstants.any)) - val retType = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + ): Ast = { + val exprNode = astsForExpression(expr.getReceiverExpression, Some(1)).headOption + .getOrElse(Ast(unknownNode(expr.getReceiverExpression, Constants.Empty))) + + val nameReferenceExpr = expr.getSelectorExpression.asInstanceOf[KtNameReferenceExpression] + val fieldIdentifier = Ast( + fieldIdentifierNode(nameReferenceExpr, nameReferenceExpr.getText, nameReferenceExpr.getText).argumentIndex(2) + ) + + val retType = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val node = withArgumentIndex( NodeBuilders.newOperatorCallNode(Operators.fieldAccess, expr.getText, Option(retType), line(expr), column(expr)), argIdx ).argumentName(argNameMaybe) - callAst(node, List(receiverAst) ++ argAsts) + callAst(node, List(exprNode, fieldIdentifier)) } private def astForQualifiedExpressionExtensionCall( expr: KtQualifiedExpression, argIdx: Option[Int], argNameMaybe: Option[String] - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val receiverAst = astsForExpression(expr.getReceiverExpression, Some(0)).headOption - .getOrElse(Ast(unknownNode(expr.getReceiverExpression, Constants.empty))) - val argAsts = selectorExpressionArgAsts(expr) + ): Ast = { + val argAsts = selectorExpressionArgAsts(expr, 2) + + // TODO fix the cast to KtCallExpression + val (fullName, signature) = calleeFullnameAndSignature( + getCalleeExpr(expr), + astDerivedFullNameWithSignature(expr, argAsts)._1, + astDerivedFullNameWithSignature(expr, argAsts)._2 + ) - val (astDerivedMethodFullName, astDerivedSignature) = astDerivedFullNameWithSignature(expr, argAsts) - val (fullName, signature) = - typeInfoProvider.fullNameWithSignature(expr, (astDerivedMethodFullName, astDerivedSignature)) - registerType(typeInfoProvider.containingDeclType(expr, TypeConstants.any)) - val retType = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + val retType = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val methodName = expr.getSelectorExpression.getFirstChild.getText val node = withArgumentIndex( @@ -155,23 +180,28 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { ), argIdx ).argumentName(argNameMaybe) - callAst(node, argAsts, Option(receiverAst)) + + val instanceArg = astsForExpression(expr.getReceiverExpression, Some(1)).headOption + .getOrElse(Ast(unknownNode(expr.getReceiverExpression, Constants.Empty))) + callAst(node, instanceArg +: argAsts) } private def astForQualifiedExpressionCallToSuper( expr: KtQualifiedExpression, argIdx: Option[Int], argNameMaybe: Option[String] - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val receiverAst = astsForExpression(expr.getReceiverExpression, Some(0)).headOption - .getOrElse(Ast(unknownNode(expr.getReceiverExpression, Constants.empty))) + .getOrElse(Ast(unknownNode(expr.getReceiverExpression, Constants.Empty))) val argAsts = selectorExpressionArgAsts(expr) - val (astDerivedMethodFullName, astDerivedSignature) = astDerivedFullNameWithSignature(expr, argAsts) - val (fullName, signature) = - typeInfoProvider.fullNameWithSignature(expr, (astDerivedMethodFullName, astDerivedSignature)) - registerType(typeInfoProvider.containingDeclType(expr, TypeConstants.any)) - val retType = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + val (fullName, signature) = calleeFullnameAndSignature( + getCalleeExpr(expr), + astDerivedFullNameWithSignature(expr, argAsts)._1, + astDerivedFullNameWithSignature(expr, argAsts)._2 + ) + + val retType = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val methodName = expr.getSelectorExpression.getFirstChild.getText val node = withArgumentIndex( @@ -193,16 +223,16 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { expr: KtQualifiedExpression, argIdx: Option[Int], argNameMaybe: Option[String] - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { expr.getSelectorExpression match { case callExpr: KtCallExpression => val localName = "tmp" - val localTypeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + val localTypeFullName = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val local = localNode(expr, localName, localName, localTypeFullName) scope.addToScope(localName, local) val localAst = Ast(local) - val typeFullName = registerType(typeInfoProvider.expressionType(expr, Defines.UnresolvedNamespace)) + val typeFullName = registerType(exprTypeFullName(expr).getOrElse(Defines.UnresolvedNamespace)) val rhsAst = Ast(NodeBuilders.newOperatorCallNode(Operators.alloc, Operators.alloc, Option(typeFullName))) val identifier = identifierNode(expr, localName, localName, local.typeFullName) @@ -218,23 +248,24 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val assignmentCallAst = callAst(assignmentNode, List(identifierAst) ++ List(rhsAst)) val (fullName, signature) = - typeInfoProvider.fullNameWithSignature(callExpr, (TypeConstants.any, TypeConstants.any)) + calleeFullnameAndSignature( + getCalleeExpr(expr), + Defines.UnresolvedNamespace, + s"${Defines.UnresolvedSignature}(${callExpr.getValueArguments.size()})" + ) val initCallNode = callNode( callExpr, callExpr.getText, - Constants.init, + Defines.ConstructorMethodName, fullName, DispatchTypes.STATIC_DISPATCH, Some(signature), - Some(TypeConstants.void) + Some(TypeConstants.Void) ) val initReceiverNode = identifierNode(expr, identifier.name, identifier.name, identifier.typeFullName) val initReceiverAst = Ast(initReceiverNode).withRefEdge(initReceiverNode, local) - val argAsts = withIndex(callExpr.getValueArguments.asScala.toSeq) { case (arg, idx) => - val argNameOpt = if (arg.isNamed) Option(arg.getArgumentName.getAsName.toString) else None - astsForExpression(arg.getArgumentExpression, Option(idx), argNameOpt) - }.flatten + val argAsts = astsForKtCallExpressionArguments(callExpr) val initAst = callAst(initCallNode, argAsts, Option(initReceiverAst)) val returningIdentifierNode = identifierNode(expr, identifier.name, identifier.name, identifier.typeFullName) @@ -247,7 +278,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } blockAst(node, List(localAst, assignmentCallAst, initAst, returningIdentifierAst)) case _ => - val node = blockNode(expr, "", TypeConstants.any).argumentName(argNameMaybe) + val node = blockNode(expr, "", TypeConstants.Any).argumentName(argNameMaybe) argIdx match { case Some(idx) => node.argumentIndex(idx) case _ => @@ -260,16 +291,17 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { expr: KtQualifiedExpression, argIdx: Option[Int], argNameMaybe: Option[String] - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val receiverAst = astsForExpression(expr.getReceiverExpression, Some(1)).headOption - .getOrElse(Ast(unknownNode(expr.getReceiverExpression, Constants.empty))) + .getOrElse(Ast(unknownNode(expr.getReceiverExpression, Constants.Empty))) val argAsts = selectorExpressionArgAsts(expr) - val (astDerivedMethodFullName, astDerivedSignature) = astDerivedFullNameWithSignature(expr, argAsts) - val (fullName, signature) = - typeInfoProvider.fullNameWithSignature(expr, (astDerivedMethodFullName, astDerivedSignature)) - registerType(typeInfoProvider.containingDeclType(expr, TypeConstants.any)) - val retType = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + val (fullName, signature) = calleeFullnameAndSignature( + getCalleeExpr(expr), + astDerivedFullNameWithSignature(expr, argAsts)._1, + astDerivedFullNameWithSignature(expr, argAsts)._2 + ) + val retType = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val methodName = expr.getSelectorExpression.getFirstChild.getText val dispatchType = DispatchTypes.STATIC_DISPATCH @@ -288,7 +320,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { callKind: CallKind, argIdx: Option[Int], argNameMaybe: Option[String] - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val isDynamicCall = callKind == CallKind.DynamicCall val isStaticCall = callKind == CallKind.StaticCall val argIdxForReceiver = @@ -300,14 +332,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { else DispatchTypes.STATIC_DISPATCH val receiverAst = astsForExpression(expr.getReceiverExpression, Some(argIdxForReceiver)).headOption - .getOrElse(Ast(unknownNode(expr.getReceiverExpression, Constants.empty))) + .getOrElse(Ast(unknownNode(expr.getReceiverExpression, Constants.Empty))) val argAsts = selectorExpressionArgAsts(expr) - val (astDerivedMethodFullName, astDerivedSignature) = astDerivedFullNameWithSignature(expr, argAsts) - val (fullName, signature) = - typeInfoProvider.fullNameWithSignature(expr, (astDerivedMethodFullName, astDerivedSignature)) - registerType(typeInfoProvider.containingDeclType(expr, TypeConstants.any)) - val retType = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + val (fullName, signature) = calleeFullnameAndSignature( + getCalleeExpr(expr), + astDerivedFullNameWithSignature(expr, argAsts)._1, + astDerivedFullNameWithSignature(expr, argAsts)._2 + ) + val retType = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val methodName = expr.getSelectorExpression.getFirstChild.getText val node = @@ -337,18 +370,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val callKind = typeInfoProvider.bindingKind(expr) val isExtensionCall = callKind == CallKind.ExtensionCall - val hasThisSuperOrNameRefReceiver = expr.getReceiverExpression match { - case _: KtThisExpression => true - case _: KtNameReferenceExpression => true - case _: KtSuperExpression => true - case _ => false - } val hasNameRefSelector = expr.getSelectorExpression.isInstanceOf[KtNameReferenceExpression] - val isFieldAccessCall = hasThisSuperOrNameRefReceiver && hasNameRefSelector val isCallToSuper = expr.getReceiverExpression match { case _: KtSuperExpression => true case _ => false @@ -366,10 +392,10 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val outAst = if (isCtorCtorCall.getOrElse(false)) { astForQualifiedExpressionCtor(expr, argIdx, argNameMaybe) - } else if (isFieldAccessCall) { - astForQualifiedExpressionFieldAccess(expr, argIdx, argNameMaybe) } else if (isExtensionCall) { astForQualifiedExpressionExtensionCall(expr, argIdx, argNameMaybe) + } else if (hasNameRefSelector) { + astForQualifiedExpressionFieldAccess(expr, argIdx, argNameMaybe) } else if (isCallToSuper) { astForQualifiedExpressionCallToSuper(expr, argIdx, argNameMaybe) } else if (noAstForReceiver) { @@ -385,8 +411,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argName: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + ): Ast = { + registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val args = astsForExpression(expr.getLeftHandSide, None) ++ Seq(astForTypeReference(expr.getTypeReference, None, argName)) val node = NodeBuilders.newOperatorCallNode(Operators.is, expr.getText, None, line(expr), column(expr)) @@ -399,8 +425,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argName: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + ): Ast = { + registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val args = astsForExpression(expr.getLeft, None) ++ Seq(astForTypeReference(expr.getRight, None, None)) val node = NodeBuilders.newOperatorCallNode(Operators.cast, expr.getText, None, line(expr), column(expr)) callAst(withArgumentName(withArgumentIndex(node, argIdx), argName), args.toList) @@ -412,7 +438,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { + ): Seq[Ast] = { val isCtorCall = typeInfoProvider.isConstructorCall(expr) if (isCtorCall.getOrElse(false)) astsForCtorCall(expr, argIdx, argNameMaybe, annotations) else astsForNonCtorCall(expr, argIdx, argNameMaybe, annotations) @@ -423,14 +449,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { - val declFullNameOption = typeInfoProvider.containingDeclFullName(expr) - declFullNameOption.foreach(registerType) - - val argAsts = withIndex(expr.getValueArguments.asScala.toSeq) { case (arg, idx) => - val argNameOpt = if (arg.isNamed) Option(arg.getArgumentName.getAsName.toString) else None - astsForExpression(arg.getArgumentExpression, Option(idx), argNameOpt) - }.flatten + ): Seq[Ast] = { + val argAsts = astsForKtCallExpressionArguments(expr) // TODO: add tests for the empty `referencedName` here val referencedName = Option(expr.getFirstChild) @@ -442,7 +462,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val imports = expr.getContainingKtFile.getImportList.getImports.asScala.toList val importedNames = imports.map { imp => - val importedName = Option(imp.getImportedName).map(_.toString).getOrElse(Constants.wildcardImportName) + val importedName = Option(imp.getImportedName).map(_.toString).getOrElse(Constants.WildcardImportName) importedName -> imp }.toMap @@ -454,30 +474,78 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } else { s"${expr.getContainingKtFile.getPackageFqName.toString}.$referencedName" } - lazy val typeArgs = - expr.getTypeArguments.asScala.map(x => typeInfoProvider.typeFullName(x.getTypeReference, TypeConstants.any)) - val explicitSignature = s"${TypeConstants.any}(${argAsts.map { _ => TypeConstants.any }.mkString(",")})" - val explicitFullName = - if (typeInfoProvider.typeRenderer.keepTypeArguments && typeArgs.nonEmpty) - s"$methodFqName<${typeArgs.mkString(",")}>:$explicitSignature" - else s"$methodFqName:$explicitSignature" - val (fullName, signature) = typeInfoProvider.fullNameWithSignature(expr, (explicitFullName, explicitSignature)) + val explicitSignature = s"${Defines.UnresolvedSignature}(${argAsts.size})" + val explicitFullName = methodFqName + + val funcDesc = bindingUtils.getCalledFunctionDesc(expr.getCalleeExpression) + val descFullName = funcDesc + .orElse(getAmbiguousFuncDescIfFullNamesEqual(expr.getCalleeExpression)) + .flatMap(nameRenderer.descFullName) + .getOrElse(explicitFullName) + val signature = funcDesc + .orElse(getAmbiguousFuncDescIfSignaturesEqual(expr.getCalleeExpression)) + .flatMap(nameRenderer.funcDescSignature) + .getOrElse(explicitSignature) + val fullName = nameRenderer.combineFunctionFullName(descFullName, signature) + + val resolvedCall = bindingUtils.getResolvedCallDesc(expr.getCalleeExpression) + + val (dispatchType, instanceAsArgument) = + if (resolvedCall.isEmpty) { + (DispatchTypes.STATIC_DISPATCH, false) + } else { + if (resolvedCall.get.getDispatchReceiver == null) { + (DispatchTypes.STATIC_DISPATCH, false) + } else { + resolvedCall.get.getResultingDescriptor match { + case functionDescriptor: FunctionDescriptor + if functionDescriptor.getVisibility == DescriptorVisibilities.PRIVATE => + (DispatchTypes.STATIC_DISPATCH, true) + case _ => + (DispatchTypes.DYNAMIC_DISPATCH, true) + } + } + } // TODO: add test case to confirm whether the ANY fallback makes sense (could be void) - val returnType = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) - val node = callNode( - expr, - expr.getText, - referencedName, - fullName, - DispatchTypes.STATIC_DISPATCH, - Some(signature), - Some(returnType) - ) + val returnType = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) + val node = callNode(expr, expr.getText, referencedName, fullName, dispatchType, Some(signature), Some(returnType)) + val annotationsAsts = annotations.map(astForAnnotationEntry) val astWithAnnotations = - callAst(withArgumentIndex(node, argIdx).argumentName(argNameMaybe), argAsts.toList) - .withChildren(annotationsAsts) + if (dispatchType == DispatchTypes.STATIC_DISPATCH) { + val compoundArgAsts = + if (instanceAsArgument) { + val instanceArgument = identifierNode( + expr, + Constants.ThisName, + Constants.ThisName, + nameRenderer.typeFullName(resolvedCall.get.getDispatchReceiver.getType).getOrElse(TypeConstants.Any) + ) + val args = argAsts.prepended(Ast(instanceArgument)) + setArgumentIndices(args, 0) + args + } else { + setArgumentIndices(argAsts) + argAsts + } + + Ast(withArgumentIndex(node, argIdx).argumentName(argNameMaybe)) + .withChildren(compoundArgAsts) + .withArgEdges(node, compoundArgAsts.flatMap(_.root)) + .withChildren(annotationsAsts) + } else { + val receiverNode = identifierNode( + expr, + Constants.ThisName, + Constants.ThisName, + nameRenderer.typeFullName(resolvedCall.get.getDispatchReceiver.getType).getOrElse(TypeConstants.Any) + ) + + callAst(withArgumentIndex(node, argIdx).argumentName(argNameMaybe), argAsts, base = Some(Ast(receiverNode))) + .withChildren(annotationsAsts) + } + List(astWithAnnotations) } @@ -486,16 +554,16 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { - val typeFullName = registerType(typeInfoProvider.expressionType(expr, Defines.UnresolvedNamespace)) + ): Seq[Ast] = { + val typeFullName = registerType(exprTypeFullName(expr).getOrElse(Defines.UnresolvedNamespace)) val tmpBlockNode = blockNode(expr, "", typeFullName) - val tmpName = s"${Constants.tmpLocalPrefix}${tmpKeyPool.next}" + val tmpName = s"${Constants.TmpLocalPrefix}${tmpKeyPool.next}" val tmpLocalNode = localNode(expr, tmpName, tmpName, typeFullName) scope.addToScope(tmpName, tmpLocalNode) val tmpLocalAst = Ast(tmpLocalNode) val assignmentRhsNode = - NodeBuilders.newOperatorCallNode(Operators.alloc, Constants.alloc, Option(typeFullName), line(expr), column(expr)) + NodeBuilders.newOperatorCallNode(Operators.alloc, Constants.Alloc, Option(typeFullName), line(expr), column(expr)) val assignmentLhsNode = identifierNode(expr, tmpName, tmpName, typeFullName) val assignmentLhsAst = astWithRefEdgeMaybe(tmpName, assignmentLhsNode) @@ -509,22 +577,27 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { withIndex(expr.getValueArguments.asScala.toSeq) { case (arg, idx) => val argNameOpt = if (arg.isNamed) Option(arg.getArgumentName.getAsName.toString) else None val asts = astsForExpression(arg.getArgumentExpression, Option(idx), argNameOpt) - (asts.dropRight(1), asts.lastOption.getOrElse(Ast(unknownNode(arg.getArgumentExpression, Constants.empty)))) + (asts.dropRight(1), asts.lastOption.getOrElse(Ast(unknownNode(arg.getArgumentExpression, Constants.Empty)))) } val astsForTrails = argAstsWithTrail.map(_._2) val astsForNonTrails = argAstsWithTrail.flatMap(_._1) - val (fullName, signature) = typeInfoProvider.fullNameWithSignature(expr, (TypeConstants.any, TypeConstants.any)) - registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + val (fullName, signature) = + calleeFullnameAndSignature( + getCalleeExpr(expr), + Defines.UnresolvedNamespace, + s"${Defines.UnresolvedSignature}(${expr.getValueArguments.size()})" + ) + registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val initCallNode = callNode( expr, expr.getText, - Constants.init, + Defines.ConstructorMethodName, fullName, DispatchTypes.STATIC_DISPATCH, Some(signature), - Some(TypeConstants.void) + Some(TypeConstants.Void) ) val initCallAst = callAst(initCallNode, astsForTrails, Option(initReceiverAst)) val lastIdentifier = identifierNode(expr, tmpName, tmpName, typeFullName) @@ -544,15 +617,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argName: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val operatorType = ktTokenToOperator(forPostfixExpr = true).applyOrElse( KtPsiUtil.getOperationToken(expr), { (token: KtToken) => logger.warn(s"Unsupported token type encountered: $token") - Constants.unknownOperator + Constants.UnknownOperator } ) - val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + val typeFullName = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val args = List(astsForExpression(expr.getBaseExpression, None).headOption.getOrElse(Ast())) .filterNot(_.root == null) val node = @@ -566,15 +639,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argName: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val operatorType = ktTokenToOperator(forPostfixExpr = false).applyOrElse( KtPsiUtil.getOperationToken(expr), { (token: KtToken) => logger.warn(s"Unsupported token type encountered: $token") - Constants.unknownOperator + Constants.UnknownOperator } ) - val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + val typeFullName = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val args = List(astsForExpression(expr.getBaseExpression, None).headOption.getOrElse(Ast())) .filterNot(_.root == null) val node = @@ -588,9 +661,9 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argName: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val arrayExpr = expression.getArrayExpression - val typeFullName = registerType(typeInfoProvider.expressionType(expression, TypeConstants.any)) + val typeFullName = registerType(exprTypeFullName(expression).getOrElse(TypeConstants.Any)) val identifier = identifierNode(arrayExpr, arrayExpr.getText, arrayExpr.getText, typeFullName) val identifierAst = astWithRefEdgeMaybe(arrayExpr.getText, identifier) val astsForIndexExpr = expression.getIndexExpressions.asScala.zipWithIndex.flatMap { case (expr, idx) => diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForFunctionsCreator.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForFunctionsCreator.scala index 45ceb2aa6850..36593c8c9a01 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForFunctionsCreator.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForFunctionsCreator.scala @@ -2,33 +2,42 @@ package io.joern.kotlin2cpg.ast import io.joern.kotlin2cpg.Constants import io.joern.kotlin2cpg.types.TypeConstants -import io.joern.kotlin2cpg.types.TypeInfoProvider import io.joern.x2cpg.Ast +import io.joern.x2cpg.Defines import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.datastructures.Stack.StackWrapper import io.joern.x2cpg.utils.NodeBuilders import io.joern.x2cpg.utils.NodeBuilders.newBindingNode import io.joern.x2cpg.utils.NodeBuilders.newClosureBindingNode +import io.joern.x2cpg.utils.NodeBuilders.newIdentifierNode import io.joern.x2cpg.utils.NodeBuilders.newMethodReturnNode import io.joern.x2cpg.utils.NodeBuilders.newModifierNode -import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.EvaluationStrategies import io.shiftleft.codepropertygraph.generated.ModifierTypes import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal -import org.jetbrains.kotlin.descriptors.DescriptorVisibilities +import org.jetbrains.kotlin.com.intellij.psi.PsiElement +import org.jetbrains.kotlin.descriptors.ClassDescriptor +import org.jetbrains.kotlin.descriptors.FunctionDescriptor import org.jetbrains.kotlin.descriptors.Modality +import org.jetbrains.kotlin.descriptors.ParameterDescriptor import org.jetbrains.kotlin.psi.* +import org.jetbrains.kotlin.resolve.calls.model.ResolvedCallArgument +import org.jetbrains.kotlin.resolve.calls.tower.NewAbstractResolvedCall +import org.jetbrains.kotlin.resolve.calls.tower.PSIFunctionKotlinCallArgument +import org.jetbrains.kotlin.resolve.sam.SamConstructorDescriptor +import org.jetbrains.kotlin.resolve.sam.SamConversionResolverImplKt +import org.jetbrains.kotlin.resolve.source.KotlinSourceElement import java.util.UUID.nameUUIDFromBytes +import scala.collection.mutable import scala.jdk.CollectionConverters.* trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - private def isAbstract(ktFn: KtNamedFunction)(implicit typeInfoProvider: TypeInfoProvider): Boolean = { - typeInfoProvider.modality(ktFn).contains(Modality.ABSTRACT) - } + import AstCreator.ClosureBindingDef private def createFunctionTypeAndTypeDeclAst( node: KtNamedFunction, @@ -51,7 +60,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { methodName, astParentType = astParentType, astParentFullName = astParentFullName, - Seq(TypeConstants.kotlinFunctionXPrefix) + Seq(TypeConstants.KotlinFunctionPrefix) ) if (astParentName == NamespaceTraversal.globalNamespaceName || astParentType == Method.Label) { // Bindings for others (classes, interfaces, and such) are already created in their respective CPG creation functions @@ -62,67 +71,99 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { } } - def astsForMethod(ktFn: KtNamedFunction, needsThisParameter: Boolean = false, withVirtualModifier: Boolean = false)( - implicit typeInfoProvider: TypeInfoProvider - ): Seq[Ast] = { - val (fullName, signature) = typeInfoProvider.fullNameWithSignature(ktFn, ("", "")) - val _methodNode = methodNode(ktFn, ktFn.getName, fullName, signature, relativizedPath) + def astsForMethod(ktFn: KtNamedFunction, withVirtualModifier: Boolean = false): Seq[Ast] = { + val funcDesc = bindingUtils.getFunctionDesc(ktFn) + val descFullName = nameRenderer + .descFullName(funcDesc) + .getOrElse(s"${Defines.UnresolvedNamespace}.${ktFn.getName}") + val signature = nameRenderer + .funcDescSignature(funcDesc) + .getOrElse(s"${Defines.UnresolvedSignature}(${ktFn.getValueParameters.size()})") + val fullName = nameRenderer.combineFunctionFullName(descFullName, signature) + + val _methodNode = methodNode(ktFn, ktFn.getName, fullName, signature, relativizedPath) scope.pushNewScope(_methodNode) methodAstParentStack.push(_methodNode) - val thisParameterMaybe = if (needsThisParameter) { - val typeDeclFullName = registerType(typeInfoProvider.containingTypeDeclFullName(ktFn, TypeConstants.any)) - val node = NodeBuilders.newThisParameterNode( + val isExtensionMethod = funcDesc.getExtensionReceiverParameter != null + + val needsThisParameter = funcDesc.getDispatchReceiverParameter != null || + isExtensionMethod + + val thisParameterAsts = if (needsThisParameter) { + val typeDeclFullName = + if (funcDesc.getDispatchReceiverParameter != null) { + nameRenderer.typeFullName(funcDesc.getDispatchReceiverParameter.getType).getOrElse(TypeConstants.Any) + } else { + nameRenderer.typeFullName(funcDesc.getExtensionReceiverParameter.getType).getOrElse(TypeConstants.Any) + } + + registerType(typeDeclFullName) + + val thisParam = NodeBuilders.newThisParameterNode( typeFullName = typeDeclFullName, dynamicTypeHintFullName = Seq(typeDeclFullName) ) - scope.addToScope(Constants.this_, node) - Option(node) - } else None + if (isExtensionMethod) { + thisParam.order(1) + thisParam.index(1) + } + scope.addToScope(Constants.ThisName, thisParam) + List(Ast(thisParam)) + } else { + List.empty + } + + val valueParamStartIndex = + if (isExtensionMethod) { + 2 + } else { + 1 + } - val thisParameterAsts = thisParameterMaybe.map(List(_)).getOrElse(List()).map(Ast(_)) val methodParametersAsts = - withIndex(ktFn.getValueParameters.asScala.toSeq) { (p, idx) => astForParameter(p, idx) } + withIndex(ktFn.getValueParameters.asScala.toSeq) { (p, idx) => + astForParameter(p, valueParamStartIndex + idx - 1) + } val bodyAsts = Option(ktFn.getBodyBlockExpression) match { case Some(bodyBlockExpression) => astsForBlock(bodyBlockExpression, None, None) case None => Option(ktFn.getBodyExpression) .map { expr => - val bodyBlock = blockNode(expr, expr.getText, TypeConstants.any) + val bodyBlock = blockNode(expr, expr.getText, TypeConstants.Any) val asts = astsForExpression(expr, Some(1)) val blockChildAsts = if (asts.nonEmpty) { val allStatementsButLast = asts.dropRight(1) - val lastStatementAst = asts.lastOption.getOrElse(Ast(unknownNode(expr, Constants.empty))) - val returnAst_ = returnAst(returnNode(expr, Constants.retCode), Seq(lastStatementAst)) + val lastStatementAst = asts.lastOption.getOrElse(Ast(unknownNode(expr, Constants.Empty))) + val returnAst_ = returnAst(returnNode(expr, Constants.RetCode), Seq(lastStatementAst)) (allStatementsButLast ++ Seq(returnAst_)).toList } else List() Seq(blockAst(bodyBlock, blockChildAsts)) } .getOrElse { - val bodyBlock = blockNode(ktFn, "", TypeConstants.any) + val bodyBlock = blockNode(ktFn, "", TypeConstants.Any) Seq(blockAst(bodyBlock, List[Ast]())) } } methodAstParentStack.pop() scope.popScope() - val bodyAst = bodyAsts.headOption.getOrElse(Ast(unknownNode(ktFn, Constants.empty))) + val bodyAst = bodyAsts.headOption.getOrElse(Ast(unknownNode(ktFn, Constants.Empty))) val otherBodyAsts = bodyAsts.drop(1) - val explicitTypeName = Option(ktFn.getTypeReference).map(_.getText).getOrElse(TypeConstants.any) - val typeFullName = registerType(typeInfoProvider.returnType(ktFn, explicitTypeName)) + val explicitTypeName = Option(ktFn.getTypeReference).map(_.getText).getOrElse(TypeConstants.Any) + val typeFullName = registerType(nameRenderer.typeFullName(funcDesc.getReturnType).getOrElse(explicitTypeName)) val _methodReturnNode = newMethodReturnNode(typeFullName, None, line(ktFn), column(ktFn)) - val visibility = typeInfoProvider.visibility(ktFn) val visibilityModifierType = - modifierTypeForVisibility(visibility.getOrElse(DescriptorVisibilities.UNKNOWN)) + modifierTypeForVisibility(funcDesc.getVisibility) val visibilityModifier = NodeBuilders.newModifierNode(visibilityModifierType) val modifierNodes = if (withVirtualModifier) Seq(NodeBuilders.newModifierNode(ModifierTypes.VIRTUAL)) else Seq() - val modifiers = if (isAbstract(ktFn)) { + val modifiers = if (funcDesc.getModality == Modality.ABSTRACT) { List(visibilityModifier) ++ modifierNodes :+ NodeBuilders.newModifierNode(ModifierTypes.ABSTRACT) } else { List(visibilityModifier) ++ modifierNodes @@ -147,16 +188,88 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { ) } - def astForParameter(param: KtParameter, order: Int)(implicit typeInfoProvider: TypeInfoProvider): Ast = { + private def astsForDestructuring(param: KtParameter): Seq[Ast] = { + val decl = param.getDestructuringDeclaration + val tmpName = s"${Constants.TmpLocalPrefix}${tmpKeyPool.next}" + var localForTmp = Option.empty[NewLocal] + val additionalLocals = mutable.ArrayBuffer.empty[Ast] + + val initCallAst = if (decl.hasInitializer) { + val init = decl.getInitializer + val asts = astsForExpression(init, Some(2)) + val initAst = + if (asts.size == 1) { asts.head } + else { + val block = blockNode(init, "", "") + blockAst(block, asts.toList) + } + val local = localNode(decl, tmpName, tmpName, TypeConstants.Any) + localForTmp = Some(local) + scope.addToScope(tmpName, local) + val tmpIdentifier = newIdentifierNode(tmpName, TypeConstants.Any) + val tmpIdentifierAst = Ast(tmpIdentifier).withRefEdge(tmpIdentifier, local) + val assignmentCallNode = NodeBuilders.newOperatorCallNode( + Operators.assignment, + s"$tmpName = ${init.getText}", + None, + line(init), + column(init) + ) + callAst(assignmentCallNode, List(tmpIdentifierAst, initAst)) + } else { + val explicitTypeName = Option(param.getTypeReference) + .map(typeRef => fullNameByImportPath(typeRef, param.getContainingKtFile).getOrElse(typeRef.getText)) + .getOrElse(TypeConstants.Any) + val typeFullName = registerType( + nameRenderer.typeFullName(bindingUtils.getVariableDesc(param).get.getType).getOrElse(explicitTypeName) + ) + val localForIt = localNode(decl, "it", "it", typeFullName) + additionalLocals.addOne(Ast(localForIt)) + val identifierForIt = newIdentifierNode("it", typeFullName) + val initAst = Ast(identifierForIt).withRefEdge(identifierForIt, localForIt) + val tmpIdentifier = newIdentifierNode(tmpName, typeFullName) + val local = localNode(decl, tmpName, tmpName, typeFullName) + localForTmp = Some(local) + scope.addToScope(tmpName, local) + val tmpIdentifierAst = Ast(tmpIdentifier).withRefEdge(tmpIdentifier, local) + val assignmentCallNode = + NodeBuilders.newOperatorCallNode(Operators.assignment, s"$tmpName = it", None, line(decl), column(decl)) + callAst(assignmentCallNode, List(tmpIdentifierAst, initAst)) + } + + val localsForDestructuringVars = localsForDestructuringEntries(decl) + val assignmentsForEntries = + decl.getEntries.asScala.filterNot(_.getText == Constants.UnusedDestructuringEntryText).zipWithIndex.map { + case (entry, idx) => + val rhsBaseAst = astWithRefEdgeMaybe( + tmpName, + identifierNode(entry, tmpName, tmpName, localForTmp.map(_.typeFullName).getOrElse(TypeConstants.Any)) + ) + assignmentAstForDestructuringEntry(entry, rhsBaseAst, idx + 1) + } + + localForTmp + .map(l => Ast(l)) + .toSeq ++ additionalLocals ++ localsForDestructuringVars ++ (initCallAst +: assignmentsForEntries) + } + + def astForParameter(param: KtParameter, order: Int): Ast = { val name = if (param.getDestructuringDeclaration != null) { - Constants.paramNameLambdaDestructureDecl + s"${Constants.DestructedParamNamePrefix}${destructedParamKeyPool.next}" } else { param.getName } - val explicitTypeName = Option(param.getTypeReference).map(_.getText).getOrElse(TypeConstants.any) - val typeFullName = registerType(typeInfoProvider.parameterType(param, explicitTypeName)) - val node = parameterInNode(param, name, name, order, false, EvaluationStrategies.BY_VALUE, typeFullName) + val explicitTypeName = Option(param.getTypeReference) + .map(typeRef => + fullNameByImportPath(typeRef, param.getContainingKtFile) + .getOrElse(typeRef.getText) + ) + .getOrElse(TypeConstants.Any) + val typeFullName = registerType( + nameRenderer.typeFullName(bindingUtils.getVariableDesc(param).get.getType).getOrElse(explicitTypeName) + ) + val node = parameterInNode(param, name, name, order, false, EvaluationStrategies.BY_VALUE, typeFullName) scope.addToScope(name, node) val annotations = param.getAnnotationEntries.asScala.map(astForAnnotationEntry).toSeq @@ -170,10 +283,18 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { argIdxMaybe: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val name = nextClosureName() - val (fullName, signature) = typeInfoProvider.fullNameWithSignatureAsLambda(fn, name) - val lambdaMethodNode = methodNode(fn, name, fullName, signature, relativizedPath) + ): Ast = { + val funcDesc = bindingUtils.getFunctionDesc(fn) + val name = nameRenderer.descName(funcDesc) + val descFullName = nameRenderer + .descFullName(funcDesc) + .getOrElse(s"${Defines.UnresolvedNamespace}.$name") + val signature = nameRenderer + .funcDescSignature(funcDesc) + .getOrElse(s"${Defines.UnresolvedSignature}(${fn.getValueParameters.size()})") + val fullName = nameRenderer.combineFunctionFullName(descFullName, signature) + + val lambdaMethodNode = methodNode(fn, name, fullName, signature, relativizedPath) val closureBindingEntriesForCaptured = scope .pushClosureScope(lambdaMethodNode) @@ -211,20 +332,20 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { case None => Option(fn.getBodyExpression) .map { expr => - val bodyBlock = blockNode(expr, expr.getText, TypeConstants.any) - val returnAst_ = returnAst(returnNode(expr, Constants.retCode), astsForExpression(expr, Some(1))) + val bodyBlock = blockNode(expr, expr.getText, TypeConstants.Any) + val returnAst_ = returnAst(returnNode(expr, Constants.RetCode), astsForExpression(expr, Some(1))) Seq(blockAst(bodyBlock, localsForCaptured.map(Ast(_)) ++ List(returnAst_))) } .getOrElse { - val bodyBlock = blockNode(fn, "", TypeConstants.any) + val bodyBlock = blockNode(fn, "", TypeConstants.Any) Seq(blockAst(bodyBlock, List[Ast]())) } } - val returnTypeFullName = TypeConstants.javaLangObject + val returnTypeFullName = TypeConstants.JavaLangObject val lambdaTypeDeclFullName = fullName.split(":").head - val bodyAst = bodyAsts.headOption.getOrElse(Ast(unknownNode(fn, Constants.empty))) + val bodyAst = bodyAsts.headOption.getOrElse(Ast(unknownNode(fn, Constants.Empty))) val lambdaMethodAst = methodAst( lambdaMethodNode, parametersAsts, @@ -237,26 +358,26 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { withArgumentIndex(methodRefNode(fn, fn.getText, fullName, lambdaTypeDeclFullName), argIdxMaybe) .argumentName(argNameMaybe) + val samInterface = getSamInterface(fn) + + val baseClassFullName = samInterface.flatMap(nameRenderer.descFullName).getOrElse(Constants.UnknownLambdaBaseClass) + val lambdaTypeDecl = typeDeclNode( fn, - Constants.lambdaTypeDeclName, + Constants.LambdaTypeDeclName, lambdaTypeDeclFullName, relativizedPath, - Seq(registerType(s"${TypeConstants.kotlinFunctionXPrefix}${fn.getValueParameters.size}")), + Seq(registerType(baseClassFullName)), None ) - val lambdaBinding = newBindingNode(Constants.lambdaBindingName, signature, lambdaMethodNode.fullName) - val bindingInfo = BindingInfo( - lambdaBinding, - Seq((lambdaTypeDecl, lambdaBinding, EdgeTypes.BINDS), (lambdaBinding, lambdaMethodNode, EdgeTypes.REF)) - ) + createLambdaBindings(lambdaMethodNode, lambdaTypeDecl, samInterface) + scope.popScope() val closureBindingDefs = closureBindingEntriesForCaptured.collect { case (closureBinding, node) => ClosureBindingDef(closureBinding, _methodRefNode, node.node) } closureBindingDefs.foreach(closureBindingDefQueue.prepend) - lambdaBindingInfoQueue.prepend(bindingInfo) lambdaAstQueue.prepend(lambdaMethodAst) Ast(_methodRefNode) .withChildren(annotations.map(astForAnnotationEntry)) @@ -267,10 +388,18 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { argIdxMaybe: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val name = nextClosureName() - val (fullName, signature) = typeInfoProvider.fullNameWithSignature(expr, name) - val lambdaMethodNode = methodNode(expr, name, fullName, signature, relativizedPath) + ): Ast = { + val funcDesc = bindingUtils.getFunctionDesc(expr.getFunctionLiteral) + val name = nameRenderer.descName(funcDesc) + val descFullName = nameRenderer + .descFullName(funcDesc) + .getOrElse(s"${Defines.UnresolvedNamespace}.$name") + val signature = nameRenderer + .funcDescSignature(funcDesc) + .getOrElse(s"${Defines.UnresolvedSignature}(${expr.getFunctionLiteral.getValueParameters.size()})") + val fullName = nameRenderer.combineFunctionFullName(descFullName, signature) + + val lambdaMethodNode = methodNode(expr, name, fullName, signature, relativizedPath) val closureBindingEntriesForCaptured = scope .pushClosureScope(lambdaMethodNode) @@ -298,43 +427,36 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { scope.addToScope(capturedNodeContext.name, node) node } - val parametersAsts = typeInfoProvider.implicitParameterName(expr) match { - case Some(implicitParamName) => - val node = parameterInNode( - expr, - implicitParamName, - implicitParamName, - 1, - false, - EvaluationStrategies.BY_REFERENCE, - TypeConstants.any - ) - scope.addToScope(implicitParamName, node) - Seq(Ast(node)) - case None => - withIndex(expr.getValueParameters.asScala.toSeq) { (p, idx) => - val destructuringEntries = - Option(p.getDestructuringDeclaration) - .map(_.getEntries.asScala) - .getOrElse(Seq()) - if (destructuringEntries.nonEmpty) - destructuringEntries.filterNot(_.getText == Constants.unusedDestructuringEntryText).zipWithIndex.map { - case (entry, innerIdx) => - val name = entry.getName - val explicitTypeName = Option(entry.getTypeReference).map(_.getText).getOrElse(TypeConstants.any) - val typeFullName = registerType(typeInfoProvider.destructuringEntryType(entry, explicitTypeName)) - val node = - parameterInNode(entry, name, name, innerIdx + idx, false, EvaluationStrategies.BY_VALUE, typeFullName) - scope.addToScope(name, node) - Ast(node) - } - else Seq(astForParameter(p, idx)) - }.flatten + + val paramAsts = mutable.ArrayBuffer.empty[Ast] + val destructedParamAsts = mutable.ArrayBuffer.empty[Ast] + val valueParamStartIndex = + if (funcDesc.getExtensionReceiverParameter != null) { + // Lambdas which are arguments to function parameters defined + // like `func: extendedType.(argTypes) -> returnType` have an implicit extension receiver parameter + // which can be accessed as `this` + paramAsts.append(createImplicitParamNode(expr, funcDesc.getExtensionReceiverParameter, "this", 1)) + 2 + } else { + 1 + } + + funcDesc.getValueParameters.asScala match { + case parameters if parameters.size == 1 && !parameters.head.getSource.isInstanceOf[KotlinSourceElement] => + // Here we handle the implicit `it` parameter. + paramAsts.append(createImplicitParamNode(expr, parameters.head, "it", valueParamStartIndex)) + case parameters => + parameters.zipWithIndex.foreach { (paramDesc, idx) => + val param = paramDesc.getSource.asInstanceOf[KotlinSourceElement].getPsi.asInstanceOf[KtParameter] + paramAsts.append(astForParameter(param, valueParamStartIndex + idx)) + if (param.getDestructuringDeclaration != null) { + destructedParamAsts.appendAll(astsForDestructuring(param)) + } + } } val lastChildNotReturnExpression = !expr.getBodyExpression.getLastChild.isInstanceOf[KtReturnExpression] - val needsReturnExpression = - lastChildNotReturnExpression && !typeInfoProvider.hasApplyOrAlsoScopeFunctionParent(expr) + val needsReturnExpression = lastChildNotReturnExpression val bodyAsts = Option(expr.getBodyExpression) .map( astsForBlock( @@ -343,12 +465,15 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { None, pushToScope = false, localsForCaptured, - implicitReturnAroundLastStatement = needsReturnExpression + implicitReturnAroundLastStatement = needsReturnExpression, + Some(destructedParamAsts.toSeq) ) ) .getOrElse(Seq(Ast(NewBlock()))) - val returnTypeFullName = registerType(typeInfoProvider.returnTypeFullName(expr)) + val returnTypeFullName = registerType( + nameRenderer.typeFullName(funcDesc.getReturnType).getOrElse(TypeConstants.JavaLangObject) + ) val lambdaTypeDeclFullName = fullName.split(":").head val (bodyAst, nestedLambdaDecls) = bodyAsts.toList match { @@ -356,11 +481,11 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { if (nestedLambdaDecls.exists(_.root.exists(x => !x.isInstanceOf[NewMethod]))) logger.warn("Detected non-method related AST nodes under lambda expression. This is unexpected.") body -> nestedLambdaDecls - case Nil => Ast(unknownNode(expr, Constants.empty)) -> Nil + case Nil => Ast(unknownNode(expr, Constants.Empty)) -> Nil } val lambdaMethodAst = methodAst( lambdaMethodNode, - parametersAsts, + paramAsts.toSeq, bodyAst, newMethodReturnNode(returnTypeFullName, None, line(expr), column(expr)), newModifierNode(ModifierTypes.VIRTUAL) :: newModifierNode(ModifierTypes.LAMBDA) :: Nil @@ -370,33 +495,162 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { withArgumentIndex(methodRefNode(expr, expr.getText, fullName, lambdaTypeDeclFullName), argIdxMaybe) .argumentName(argNameMaybe) + val samInterface = getSamInterface(expr) + + val baseClassFullName = samInterface.flatMap(nameRenderer.descFullName).getOrElse(Constants.UnknownLambdaBaseClass) + val lambdaTypeDecl = typeDeclNode( expr, - Constants.lambdaTypeDeclName, + Constants.LambdaTypeDeclName, lambdaTypeDeclFullName, relativizedPath, - Seq(registerType(s"${TypeConstants.kotlinFunctionXPrefix}${expr.getValueParameters.size}")), + Seq(registerType(baseClassFullName)), None ) - val lambdaBinding = newBindingNode(Constants.lambdaBindingName, signature, lambdaMethodNode.fullName) - val bindingInfo = BindingInfo( - lambdaBinding, - Seq((lambdaTypeDecl, lambdaBinding, EdgeTypes.BINDS), (lambdaBinding, lambdaMethodNode, EdgeTypes.REF)) - ) + createLambdaBindings(lambdaMethodNode, lambdaTypeDecl, samInterface) + scope.popScope() val closureBindingDefs = closureBindingEntriesForCaptured.collect { case (closureBinding, node) => ClosureBindingDef(closureBinding, _methodRefNode, node.node) } closureBindingDefs.foreach(closureBindingDefQueue.prepend) - lambdaBindingInfoQueue.prepend(bindingInfo) lambdaAstQueue.prepend(lambdaMethodAst) nestedLambdaDecls.foreach(lambdaAstQueue.prepend) Ast(_methodRefNode) .withChildren(annotations.map(astForAnnotationEntry)) } - def astForReturnExpression(expr: KtReturnExpression)(implicit typeInfoProvider: TypeInfoProvider): Ast = { + private def createImplicitParamNode( + expr: KtLambdaExpression, + paramDesc: ParameterDescriptor, + paramName: String, + index: Int + ): Ast = { + val node = parameterInNode( + expr, + paramName, + paramName, + index, + false, + EvaluationStrategies.BY_REFERENCE, + nameRenderer.typeFullName(paramDesc.getType).getOrElse(TypeConstants.Any) + ) + scope.addToScope(paramName, node) + Ast(node) + } + + // SAM stands for: single abstraction method + private def getSamInterface(expr: KtLambdaExpression | KtNamedFunction): Option[ClassDescriptor] = { + getSurroundingCallTarget(expr) match { + case Some(callTarget) => + val resolvedCallAtom = bindingUtils + .getResolvedCallDesc(callTarget) + .collect { case call: NewAbstractResolvedCall[?] => + call.getResolvedCallAtom + } + + resolvedCallAtom.map { callAtom => + callAtom.getCandidateDescriptor match { + case samConstructorDesc: SamConstructorDescriptor => + // Lambda is wrapped e.g. `SomeInterface { obj -> obj }` + samConstructorDesc.getBaseDescriptorForSynthetic + case _ => + // Lambda/anon function is directly used as call argument e.g. `someCall(obj -> obj)` + val directCallArgumentForLookup = + expr match { + case _: KtNamedFunction => + // This is the anonymous function case. + // So far it does not seem like those could be wrapped so they are always the direct argument. + expr + case _ => + // The lambda function case. + getDirectLambdaArgument(expr).get + } + callAtom.getArgumentMappingByOriginal.asScala.collectFirst { + case (paramDesc, resolvedArgument) if isExprIncluded(resolvedArgument, directCallArgumentForLookup) => + paramDesc.getType.getConstructor.getDeclarationDescriptor.asInstanceOf[ClassDescriptor] + }.get + } + } + case None => + // Lambda/anon function is directly assigned to a variable. + // E.g. `val l = { i: Int -> i }` + val lambdaExprType = bindingUtils.getExprType(expr) + lambdaExprType.map(_.getConstructor.getDeclarationDescriptor.asInstanceOf[ClassDescriptor]) + } + } + + private def getDirectLambdaArgument(element: KtElement): Option[KtExpression] = { + var context: PsiElement = element + var parentContext: PsiElement = null + + // KtCallExpression wrap their arguments in KtValueArgument which is why we look for those. + // KtBinaryExpressions do not do such a wrapping. + while ({ + parentContext = context.getContext + parentContext != null && + !parentContext.isInstanceOf[KtValueArgument] && + !parentContext.isInstanceOf[KtBinaryExpression] + }) { + context = parentContext + } + + if (parentContext != null) { + Some(context.asInstanceOf[KtExpression]) + } else { + None + } + } + + private def getSurroundingCallTarget(element: KtElement): Option[KtExpression] = { + var context: PsiElement = element.getContext + while ( + context != null && + !context.isInstanceOf[KtCallExpression] && + !context.isInstanceOf[KtBinaryExpression] + ) { + context = context.getContext + } + context match { + case callExpr: KtCallExpression => + Some(callExpr.getCalleeExpression) + case binaryExpr: KtBinaryExpression => + Some(binaryExpr.getOperationReference) + case null => + None + } + } + + private def isExprIncluded(resolvedArgument: ResolvedCallArgument, expr: KtExpression): Boolean = { + // getArguments returns multiple arguments in case of varargs + resolvedArgument.getArguments.asScala.exists { + case psi: PSIFunctionKotlinCallArgument => + psi.getExpression == expr + case _ => + false + } + } + + private def createLambdaBindings( + lambdaMethodNode: NewMethod, + lambdaTypeDecl: NewTypeDecl, + samInterface: Option[ClassDescriptor] + ): Unit = { + val samMethod = samInterface.map(SamConversionResolverImplKt.getSingleAbstractMethodOrNull) + val samMethodName = samMethod.map(_.getName.toString).getOrElse(Constants.UnknownLambdaBindingName) + val samMethodSignature = samMethod.flatMap(nameRenderer.funcDescSignature) + + if (samMethodSignature.isDefined) { + val interfaceLambdaBinding = newBindingNode(samMethodName, samMethodSignature.get, lambdaMethodNode.fullName) + addToLambdaBindingInfoQueue(interfaceLambdaBinding, lambdaTypeDecl, lambdaMethodNode) + } + + val nativeLambdaBinding = newBindingNode(samMethodName, lambdaMethodNode.signature, lambdaMethodNode.fullName) + addToLambdaBindingInfoQueue(nativeLambdaBinding, lambdaTypeDecl, lambdaMethodNode) + } + + def astForReturnExpression(expr: KtReturnExpression): Ast = { val returnedExpr = if (expr.getReturnedExpression != null) { astsForExpression(expr.getReturnedExpression, None) diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForPrimitivesCreator.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForPrimitivesCreator.scala index e6d85bb26cb9..68a4c2738c60 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForPrimitivesCreator.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForPrimitivesCreator.scala @@ -2,7 +2,6 @@ package io.joern.kotlin2cpg.ast import io.joern.kotlin2cpg.Constants import io.joern.kotlin2cpg.types.TypeConstants -import io.joern.kotlin2cpg.types.TypeInfoProvider import io.joern.x2cpg.Ast import io.joern.x2cpg.Defines import io.joern.x2cpg.ValidationMode @@ -17,6 +16,9 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewMember import io.shiftleft.codepropertygraph.generated.nodes.NewMethodParameterIn import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal +import org.jetbrains.kotlin.descriptors.ClassifierDescriptor +import org.jetbrains.kotlin.descriptors.PropertyDescriptor +import org.jetbrains.kotlin.descriptors.ValueDescriptor import org.jetbrains.kotlin.psi.KtAnnotationEntry import org.jetbrains.kotlin.psi.KtClassLiteralExpression import org.jetbrains.kotlin.psi.KtConstantExpression @@ -27,6 +29,7 @@ import org.jetbrains.kotlin.psi.KtSuperExpression import org.jetbrains.kotlin.psi.KtThisExpression import org.jetbrains.kotlin.psi.KtTypeAlias import org.jetbrains.kotlin.psi.KtTypeReference +import org.jetbrains.kotlin.types.error.ErrorType import scala.annotation.unused import scala.jdk.CollectionConverters.* @@ -40,8 +43,8 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argName: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + ): Ast = { + val typeFullName = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val node = literalNode(expr, expr.getText, typeFullName) val annotationAsts = annotations.map(astForAnnotationEntry) Ast(withArgumentName(withArgumentIndex(node, argIdx), argName)).withChildren(annotationAsts) @@ -52,12 +55,12 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argName: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + ): Ast = { + val typeFullName = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val outAst = if (expr.hasInterpolation) { val args = expr.getEntries.filter(_.getExpression != null).zipWithIndex.map { case (entry, idx) => - val entryTypeFullName = registerType(typeInfoProvider.expressionType(entry.getExpression, TypeConstants.any)) + val entryTypeFullName = registerType(exprTypeFullName(entry.getExpression).getOrElse(TypeConstants.Any)) val valueCallNode = NodeBuilders.newOperatorCallNode( Operators.formattedValue, entry.getExpression.getText, @@ -89,28 +92,29 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argName: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val isReferencingMember = scope.lookupVariable(expr.getIdentifier.getText) match { case Some(_: NewMember) => true case _ => false } + val isUsedAsImplicitThis = typeInfoProvider.usedAsImplicitThis(expr) val outAst = if (typeInfoProvider.isReferenceToClass(expr)) astForNameReferenceToType(expr, argIdx) - else if (isReferencingMember) astForNameReferenceToMember(expr, argIdx) + else if (isReferencingMember || isUsedAsImplicitThis) astForNameReferenceToMember(expr, argIdx) else astForNonSpecialNameReference(expr, argIdx, argName) outAst.withChildren(annotations.map(astForAnnotationEntry)) } - private def astForNameReferenceToType(expr: KtNameReferenceExpression, argIdx: Option[Int])(implicit - typeInfoProvider: TypeInfoProvider - ): Ast = { - val typeFullName = registerType(typeInfoProvider.typeFullName(expr, TypeConstants.any)) + private def astForNameReferenceToType(expr: KtNameReferenceExpression, argIdx: Option[Int]): Ast = { + val declDesc = + bindingUtils.getDeclDesc(expr).collect { case classifierDesc: ClassifierDescriptor => classifierDesc } + val typeFullName = registerType(declDesc.flatMap(nameRenderer.descFullName).getOrElse(TypeConstants.Any)) val referencesCompanionObject = typeInfoProvider.isRefToCompanionObject(expr) if (referencesCompanionObject) { val argAsts = List( // TODO: change this to a TYPE_REF node as soon as the closed source data-flow engine supports it identifierNode(expr, expr.getIdentifier.getText, expr.getIdentifier.getText, typeFullName), - fieldIdentifierNode(expr, Constants.companionObjectMemberName, Constants.companionObjectMemberName) + fieldIdentifierNode(expr, Constants.CompanionObjectMemberName, Constants.CompanionObjectMemberName) ).map(Ast(_)) val node = NodeBuilders.newOperatorCallNode( Operators.fieldAccess, @@ -126,19 +130,24 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { } } - private def astForNameReferenceToMember(expr: KtNameReferenceExpression, argIdx: Option[Int])(implicit - typeInfoProvider: TypeInfoProvider - ): Ast = { - val typeFullName = registerType(typeInfoProvider.typeFullName(expr, TypeConstants.any)) - val referenceTargetTypeFullName = registerType( - typeInfoProvider.referenceTargetTypeFullName(expr, TypeConstants.any) - ) - val thisNode = identifierNode(expr, Constants.this_, Constants.this_, referenceTargetTypeFullName) - val thisAst = astWithRefEdgeMaybe(Constants.this_, thisNode) + private def astForNameReferenceToMember(expr: KtNameReferenceExpression, argIdx: Option[Int]): Ast = { + val declDesc = bindingUtils.getDeclDesc(expr).collect { case propDesc: PropertyDescriptor => propDesc } + val typeFullName = declDesc + .flatMap(desc => nameRenderer.typeFullName(desc.getType)) + .getOrElse(TypeConstants.Any) + registerType(typeFullName) + + val baseTypeFullName = declDesc + .flatMap(desc => nameRenderer.typeFullName(desc.getDispatchReceiverParameter.getType)) + .getOrElse(TypeConstants.Any) + registerType(baseTypeFullName) + + val thisNode = identifierNode(expr, Constants.ThisName, Constants.ThisName, baseTypeFullName) + val thisAst = astWithRefEdgeMaybe(Constants.ThisName, thisNode) val _fieldIdentifierNode = fieldIdentifierNode(expr, expr.getReferencedName, expr.getReferencedName) val node = NodeBuilders.newOperatorCallNode( Operators.fieldAccess, - s"${Constants.this_}.${expr.getReferencedName}", + s"${Constants.ThisName}.${expr.getReferencedName}", Option(typeFullName), line(expr), column(expr) @@ -150,18 +159,20 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { expr: KtNameReferenceExpression, argIdx: Option[Int], argName: Option[String] = None - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val typeFromScopeMaybe = scope.lookupVariable(expr.getIdentifier.getText) match { - case Some(n: NewLocal) => Some(n.typeFullName) - case Some(n: NewMethodParameterIn) => Some(n.typeFullName) - case _ => None - } - val typeFromProvider = typeInfoProvider.typeFullName(expr, Defines.UnresolvedNamespace) - val typeFullName = - typeFromScopeMaybe match { - case Some(fullName) => registerType(fullName) - case None => registerType(typeFromProvider) + ): Ast = { + val declDesc = bindingUtils.getDeclDesc(expr).collect { case valueDesc: ValueDescriptor => valueDesc } + val typeFullName = declDesc + .flatMap(desc => nameRenderer.typeFullName(desc.getType)) + .orElse { + val typeFromScopeMaybe = scope.lookupVariable(expr.getIdentifier.getText) match { + case Some(n: NewLocal) => Some(n.typeFullName) + case Some(n: NewMethodParameterIn) => Some(n.typeFullName) + case _ => None + } + typeFromScopeMaybe } + .getOrElse(TypeConstants.Any) + val name = expr.getIdentifier.getText val node = withArgumentName(withArgumentIndex(identifierNode(expr, name, name, typeFullName), argIdx), argName) @@ -173,8 +184,8 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argName: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + ): Ast = { + val typeFullName = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val node = withArgumentName( withArgumentIndex(identifierNode(expr, expr.getText, expr.getText, typeFullName), argIdx), argName @@ -188,8 +199,8 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argName: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + ): Ast = { + val typeFullName = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val node = withArgumentName( withArgumentIndex(identifierNode(expr, expr.getText, expr.getText, typeFullName), argIdx), argName @@ -203,13 +214,14 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argName: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val (fullName, signature) = typeInfoProvider.fullNameWithSignature(expr, ("", "")) // TODO: fix the fallback names - val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.javaLangObject)) + ): Ast = { + val typeFullName = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.JavaLangObject)) + val fullName = ".class" + val signature = s"$typeFullName()" val node = callNode( expr, expr.getText, - TypeConstants.classLiteralReplacementMethodName, + fullName, fullName, DispatchTypes.STATIC_DISPATCH, Some(signature), @@ -221,14 +233,14 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { def astForImportDirective(directive: KtImportDirective): Ast = { val importedAs = Try(directive.getImportedName.getIdentifier).toOption - val isWildcard = importedAs.contains(Constants.wildcardImportName) || directive.getImportedName == null + val isWildcard = importedAs.contains(Constants.WildcardImportName) || directive.getImportedName == null val node = NewImport() .isWildcard(isWildcard) .isExplicit(true) .importedAs(importedAs) .importedEntity(directive.getImportPath.getPathStr) - .code(s"${Constants.importKeyword} ${directive.getImportPath.getPathStr}") + .code(s"${Constants.ImportKeyword} ${directive.getImportPath.getPathStr}") .lineNumber(line(directive)) .columnNumber(column(directive)) Ast(node) @@ -237,7 +249,7 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { @unused def astForPackageDeclaration(packageName: String): Ast = { val node = - if (packageName == Constants.root) + if (packageName == Constants.Root) NodeBuilders.newNamespaceBlockNode( NamespaceTraversal.globalNamespaceName, NamespaceTraversal.globalNamespaceName, @@ -250,14 +262,13 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { Ast(node) } - def astForAnnotationEntry(entry: KtAnnotationEntry)(implicit typeInfoProvider: TypeInfoProvider): Ast = { - val typeFullName = registerType(typeInfoProvider.typeFullName(entry, TypeConstants.any) match { - case value if value != TypeConstants.any => value - case _ => - typeInfoProvider - .typeFromImports(entry.getShortName.toString, entry.getContainingKtFile) - .getOrElse(TypeConstants.any) - }) + def astForAnnotationEntry(entry: KtAnnotationEntry): Ast = { + val typeFullName = nameRenderer + .typeFullName(bindingUtils.getAnnotationDesc(entry).getType) + .orElse(fullNameByImportPath(entry.getTypeReference, entry.getContainingKtFile)) + .getOrElse(s"${Defines.UnresolvedNamespace}.${entry.getShortName.toString}") + registerType(typeFullName) + val node = NewAnnotation() .code(entry.getText) @@ -281,23 +292,31 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { annotationAst(node, children) } - def astForTypeAlias(typeAlias: KtTypeAlias)(implicit typeInfoProvider: TypeInfoProvider): Ast = { + def astForTypeAlias(typeAlias: KtTypeAlias): Ast = { + val typeAliasDesc = bindingUtils.getTypeAliasDesc(typeAlias) + val aliasedType = typeAliasDesc.getExpandedType match { + case _: ErrorType => + None + case nonErrorType => + Some(nonErrorType) + } + val node = typeDeclNode( typeAlias, typeAlias.getName, - registerType(typeInfoProvider.fullName(typeAlias, TypeConstants.any)), + registerType(nameRenderer.descFullName(typeAliasDesc).getOrElse(TypeConstants.Any)), relativizedPath, Seq(), - Option(registerType(typeInfoProvider.aliasTypeFullName(typeAlias, TypeConstants.any))) + Option(registerType(aliasedType.flatMap(nameRenderer.typeFullName).getOrElse(TypeConstants.Any))) ) Ast(node) } - def astForTypeReference(expr: KtTypeReference, argIdx: Option[Int], argName: Option[String])(implicit - typeInfoProvider: TypeInfoProvider - ): Ast = { - val typeFullName = registerType(typeInfoProvider.typeFullName(expr, TypeConstants.any)) - val node = typeRefNode(expr, expr.getText, typeFullName) + def astForTypeReference(expr: KtTypeReference, argIdx: Option[Int], argName: Option[String]): Ast = { + val typeFullName = registerType( + bindingUtils.getTypeRefType(expr).flatMap(nameRenderer.typeFullName).getOrElse(TypeConstants.Any) + ) + val node = typeRefNode(expr, expr.getText, typeFullName) Ast(withArgumentName(withArgumentIndex(node, argIdx), argName)) } } diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForStatementsCreator.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForStatementsCreator.scala index 65ce497f60a9..a28ad5e2a99e 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForStatementsCreator.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForStatementsCreator.scala @@ -2,7 +2,6 @@ package io.joern.kotlin2cpg.ast import io.joern.kotlin2cpg.Constants import io.joern.kotlin2cpg.types.TypeConstants -import io.joern.kotlin2cpg.types.TypeInfoProvider import io.joern.x2cpg.Ast import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.utils.NodeBuilders @@ -35,9 +34,7 @@ import scala.jdk.CollectionConverters.* trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - def astForFor(expr: KtForExpression, annotations: Seq[KtAnnotationEntry] = Seq())(implicit - typeInfoProvider: TypeInfoProvider - ): Ast = { + def astForFor(expr: KtForExpression, annotations: Seq[KtAnnotationEntry] = Seq()): Ast = { val outAst = if (expr.getDestructuringDeclaration != null) astForForWithDestructuringLHS(expr) else astForForWithSimpleVarLHS(expr) @@ -60,25 +57,25 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { // |-> loweringOf{d2 = tmp.component2()} // |-> // - private def astForForWithDestructuringLHS(expr: KtForExpression)(implicit typeInfoProvider: TypeInfoProvider): Ast = { + private def astForForWithDestructuringLHS(expr: KtForExpression): Ast = { val loopRangeText = expr.getLoopRange.getText - val iteratorName = s"${Constants.iteratorPrefix}${iteratorKeyPool.next()}" - val localForIterator = localNode(expr, iteratorName, iteratorName, TypeConstants.any) - val iteratorAssignmentLhs = newIdentifierNode(iteratorName, TypeConstants.any) + val iteratorName = s"${Constants.IteratorPrefix}${iteratorKeyPool.next}" + val localForIterator = localNode(expr, iteratorName, iteratorName, TypeConstants.Any) + val iteratorAssignmentLhs = newIdentifierNode(iteratorName, TypeConstants.Any) val iteratorLocalAst = Ast(localForIterator).withRefEdge(iteratorAssignmentLhs, localForIterator) // TODO: maybe use a different method here, one which does not translate `kotlin.collections.List` to `java.util.List` - val loopRangeExprTypeFullName = registerType(typeInfoProvider.expressionType(expr.getLoopRange, TypeConstants.any)) + val loopRangeExprTypeFullName = registerType(exprTypeFullName(expr.getLoopRange).getOrElse(TypeConstants.Any)) val iteratorAssignmentRhsIdentifier = newIdentifierNode(loopRangeText, loopRangeExprTypeFullName) .argumentIndex(0) val iteratorAssignmentRhs = callNode( expr.getLoopRange, - s"$loopRangeText.${Constants.getIteratorMethodName}()", - Constants.getIteratorMethodName, - s"$loopRangeExprTypeFullName.${Constants.getIteratorMethodName}:${Constants.javaUtilIterator}()", + s"$loopRangeText.${Constants.GetIteratorMethodName}()", + Constants.GetIteratorMethodName, + s"$loopRangeExprTypeFullName.${Constants.GetIteratorMethodName}:${Constants.JavaUtilIterator}()", DispatchTypes.DYNAMIC_DISPATCH, - Some(s"${Constants.javaUtilIterator}()"), - Some(Constants.javaUtilIterator) + Some(s"${Constants.JavaUtilIterator}()"), + Some(Constants.JavaUtilIterator) ) val iteratorAssignmentRhsAst = @@ -92,50 +89,40 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { val conditionIdentifier = newIdentifierNode(loopRangeText, loopRangeExprTypeFullName).argumentIndex(0) val hasNextFullName = - s"${Constants.collectionsIteratorName}.${Constants.hasNextIteratorMethodName}:${TypeConstants.javaLangBoolean}()" + s"${Constants.CollectionsIteratorName}.${Constants.HasNextIteratorMethodName}:${TypeConstants.JavaLangBoolean}()" val controlStructureCondition = callNode( expr.getLoopRange, - s"$iteratorName.${Constants.hasNextIteratorMethodName}()", - Constants.hasNextIteratorMethodName, + s"$iteratorName.${Constants.HasNextIteratorMethodName}()", + Constants.HasNextIteratorMethodName, hasNextFullName, DispatchTypes.DYNAMIC_DISPATCH, - Some(s"${TypeConstants.javaLangBoolean}()"), - Some(TypeConstants.javaLangBoolean) + Some(s"${TypeConstants.JavaLangBoolean}()"), + Some(TypeConstants.JavaLangBoolean) ).argumentIndex(0) val controlStructureConditionAst = callAst(controlStructureCondition, List(), Option(Ast(conditionIdentifier))) - val destructuringDeclEntries = expr.getDestructuringDeclaration.getEntries - val localsForDestructuringVars = - destructuringDeclEntries.asScala - .filterNot(_.getText == Constants.unusedDestructuringEntryText) - .map { entry => - val entryTypeFullName = registerType(typeInfoProvider.typeFullName(entry, TypeConstants.any)) - val entryName = entry.getText - val node = localNode(entry, entryName, entryName, entryTypeFullName) - scope.addToScope(entryName, node) - Ast(node) - } - .toList + val destructuringDeclEntries = expr.getDestructuringDeclaration.getEntries + val localsForDestructuringVars = localsForDestructuringEntries(expr.getDestructuringDeclaration) - val tmpName = s"${Constants.tmpLocalPrefix}${tmpKeyPool.next}" - val localForTmp = localNode(expr, tmpName, tmpName, TypeConstants.any) + val tmpName = s"${Constants.TmpLocalPrefix}${tmpKeyPool.next}" + val localForTmp = localNode(expr, tmpName, tmpName, TypeConstants.Any) scope.addToScope(localForTmp.name, localForTmp) val localForTmpAst = Ast(localForTmp) - val tmpIdentifier = newIdentifierNode(tmpName, TypeConstants.any) + val tmpIdentifier = newIdentifierNode(tmpName, TypeConstants.Any) val tmpIdentifierAst = Ast(tmpIdentifier).withRefEdge(tmpIdentifier, localForTmp) - val iteratorNextIdentifier = newIdentifierNode(iteratorName, TypeConstants.any).argumentIndex(0) + val iteratorNextIdentifier = newIdentifierNode(iteratorName, TypeConstants.Any).argumentIndex(0) val iteratorNextIdentifierAst = Ast(iteratorNextIdentifier).withRefEdge(iteratorNextIdentifier, localForIterator) val iteratorNextCall = callNode( expr.getLoopRange, - s"${iteratorNextIdentifier.code}.${Constants.nextIteratorMethodName}()", - Constants.nextIteratorMethodName, - s"${Constants.collectionsIteratorName}.${Constants.nextIteratorMethodName}:${TypeConstants.javaLangObject}()", + s"${iteratorNextIdentifier.code}.${Constants.NextIteratorMethodName}()", + Constants.NextIteratorMethodName, + s"${Constants.CollectionsIteratorName}.${Constants.NextIteratorMethodName}:${TypeConstants.JavaLangObject}()", DispatchTypes.DYNAMIC_DISPATCH, - Some(s"${TypeConstants.javaLangObject}()"), - Some(TypeConstants.javaLangObject) + Some(s"${TypeConstants.JavaLangObject}()"), + Some(TypeConstants.JavaLangObject) ) val iteratorNextCallAst = @@ -145,16 +132,22 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { val tmpParameterNextAssignmentAst = callAst(tmpParameterNextAssignment, List(tmpIdentifierAst, iteratorNextCallAst)) val assignmentsForEntries = - destructuringDeclEntries.asScala.filterNot(_.getText == Constants.unusedDestructuringEntryText).zipWithIndex.map { + destructuringDeclEntries.asScala.filterNot(_.getText == Constants.UnusedDestructuringEntryText).zipWithIndex.map { case (entry, idx) => - assignmentAstForDestructuringEntry(entry, localForTmp.name, localForTmp.typeFullName, idx + 1) + val rhsBaseAst = + astWithRefEdgeMaybe( + localForTmp.name, + identifierNode(entry, localForTmp.name, localForTmp.name, localForTmp.typeFullName) + .argumentIndex(0) + ) + assignmentAstForDestructuringEntry(entry, rhsBaseAst, idx + 1) } val stmtAsts = astsForExpression(expr.getBody, None) val controlStructureBody = blockNode(expr.getBody, "", "") val controlStructureBodyAst = blockAst( controlStructureBody, - localsForDestructuringVars ++ + localsForDestructuringVars.toList ++ List(localForTmpAst, tmpParameterNextAssignmentAst) ++ assignmentsForEntries ++ stmtAsts @@ -163,7 +156,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { val _controlStructureAst = controlStructureAst(controlStructure, Some(controlStructureConditionAst), Seq(controlStructureBodyAst)) blockAst( - blockNode(expr, Constants.codeForLoweredForBlock, ""), + blockNode(expr, Constants.CodeForLoweredForBlock, ""), List(iteratorLocalAst, iteratorAssignmentAst, _controlStructureAst) ) } @@ -180,25 +173,25 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { // |-> loweringOf{one = iterator.next()} // |-> // - private def astForForWithSimpleVarLHS(expr: KtForExpression)(implicit typeInfoProvider: TypeInfoProvider): Ast = { + private def astForForWithSimpleVarLHS(expr: KtForExpression): Ast = { val loopRangeText = expr.getLoopRange.getText - val iteratorName = s"${Constants.iteratorPrefix}${iteratorKeyPool.next()}" - val iteratorLocal = localNode(expr, iteratorName, iteratorName, TypeConstants.any) - val iteratorAssignmentLhs = newIdentifierNode(iteratorName, TypeConstants.any) + val iteratorName = s"${Constants.IteratorPrefix}${iteratorKeyPool.next}" + val iteratorLocal = localNode(expr, iteratorName, iteratorName, TypeConstants.Any) + val iteratorAssignmentLhs = newIdentifierNode(iteratorName, TypeConstants.Any) val iteratorLocalAst = Ast(iteratorLocal).withRefEdge(iteratorAssignmentLhs, iteratorLocal) - val loopRangeExprTypeFullName = registerType(typeInfoProvider.expressionType(expr.getLoopRange, TypeConstants.any)) + val loopRangeExprTypeFullName = registerType(exprTypeFullName(expr.getLoopRange).getOrElse(TypeConstants.Any)) val iteratorAssignmentRhsIdentifier = newIdentifierNode(loopRangeText, loopRangeExprTypeFullName) .argumentIndex(0) val iteratorAssignmentRhs = callNode( expr.getLoopRange, - s"$loopRangeText.${Constants.getIteratorMethodName}()", - Constants.getIteratorMethodName, - s"$loopRangeExprTypeFullName.${Constants.getIteratorMethodName}:${Constants.javaUtilIterator}()", + s"$loopRangeText.${Constants.GetIteratorMethodName}()", + Constants.GetIteratorMethodName, + s"$loopRangeExprTypeFullName.${Constants.GetIteratorMethodName}:${Constants.JavaUtilIterator}()", DispatchTypes.DYNAMIC_DISPATCH, - Some(s"${Constants.javaUtilIterator}()"), - Some(Constants.javaUtilIterator) + Some(s"${Constants.JavaUtilIterator}()"), + Some(Constants.JavaUtilIterator) ) val iteratorAssignmentRhsAst = @@ -212,40 +205,43 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { val conditionIdentifier = newIdentifierNode(loopRangeText, loopRangeExprTypeFullName).argumentIndex(0) val hasNextFullName = - s"${Constants.collectionsIteratorName}.${Constants.hasNextIteratorMethodName}:${TypeConstants.javaLangBoolean}()" + s"${Constants.CollectionsIteratorName}.${Constants.HasNextIteratorMethodName}:${TypeConstants.JavaLangBoolean}()" val controlStructureCondition = callNode( expr.getLoopRange, - s"$iteratorName.${Constants.hasNextIteratorMethodName}()", - Constants.hasNextIteratorMethodName, + s"$iteratorName.${Constants.HasNextIteratorMethodName}()", + Constants.HasNextIteratorMethodName, hasNextFullName, DispatchTypes.DYNAMIC_DISPATCH, - Some(s"${TypeConstants.javaLangBoolean}()"), - Some(TypeConstants.javaLangBoolean) + Some(s"${TypeConstants.JavaLangBoolean}()"), + Some(TypeConstants.JavaLangBoolean) ).argumentIndex(0) val controlStructureConditionAst = callAst(controlStructureCondition, List(), Option(Ast(conditionIdentifier))) val loopParameterTypeFullName = registerType( - typeInfoProvider.typeFullName(expr.getLoopParameter, TypeConstants.any) + bindingUtils + .getVariableDesc(expr.getLoopParameter) + .flatMap(desc => nameRenderer.typeFullName(desc.getType)) + .getOrElse(TypeConstants.Any) ) val loopParameterName = expr.getLoopParameter.getText val loopParameterLocal = localNode(expr, loopParameterName, loopParameterName, loopParameterTypeFullName) scope.addToScope(loopParameterName, loopParameterLocal) - val loopParameterIdentifier = newIdentifierNode(loopParameterName, TypeConstants.any) + val loopParameterIdentifier = newIdentifierNode(loopParameterName, TypeConstants.Any) val loopParameterAst = Ast(loopParameterLocal).withRefEdge(loopParameterIdentifier, loopParameterLocal) - val iteratorNextIdentifier = newIdentifierNode(iteratorName, TypeConstants.any).argumentIndex(0) + val iteratorNextIdentifier = newIdentifierNode(iteratorName, TypeConstants.Any).argumentIndex(0) val iteratorNextIdentifierAst = Ast(iteratorNextIdentifier).withRefEdge(iteratorNextIdentifier, iteratorLocal) val iteratorNextCall = callNode( expr.getLoopParameter, - s"$iteratorName.${Constants.nextIteratorMethodName}()", - Constants.nextIteratorMethodName, - s"${Constants.collectionsIteratorName}.${Constants.nextIteratorMethodName}:${TypeConstants.javaLangObject}()", + s"$iteratorName.${Constants.NextIteratorMethodName}()", + Constants.NextIteratorMethodName, + s"${Constants.CollectionsIteratorName}.${Constants.NextIteratorMethodName}:${TypeConstants.JavaLangObject}()", DispatchTypes.DYNAMIC_DISPATCH, - Some(s"${TypeConstants.javaLangObject}()"), - Some(TypeConstants.javaLangObject) + Some(s"${TypeConstants.JavaLangObject}()"), + Some(TypeConstants.JavaLangObject) ) val iteratorNextCallAst = callAst(iteratorNextCall, Seq(), Option(iteratorNextIdentifierAst)) @@ -262,7 +258,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { val _controlStructureAst = controlStructureAst(controlStructure, Some(controlStructureConditionAst), Seq(controlStructureBodyAst)) blockAst( - blockNode(expr, Constants.codeForLoweredForBlock, ""), + blockNode(expr, Constants.CodeForLoweredForBlock, ""), List(iteratorLocalAst, iteratorAssignmentAst, _controlStructureAst) ) } @@ -272,15 +268,13 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val isChildOfControlStructureBody = expr.getParent.isInstanceOf[KtContainerNodeForControlStructureBody] if (KtPsiUtil.isStatement(expr) && !isChildOfControlStructureBody) astForIfAsControlStructure(expr, annotations) else astForIfAsExpression(expr, argIdx, argNameMaybe, annotations) } - private def astForIfAsControlStructure(expr: KtIfExpression, annotations: Seq[KtAnnotationEntry] = Seq())(implicit - typeInfoProvider: TypeInfoProvider - ): Ast = { + private def astForIfAsControlStructure(expr: KtIfExpression, annotations: Seq[KtAnnotationEntry] = Seq()): Ast = { val conditionAst = astsForExpression(expr.getCondition, None).headOption val thenAsts = astsForExpression(expr.getThen, None) val elseAsts = Option(expr.getElse).toSeq.flatMap(astsForExpression(_, None)) @@ -295,14 +289,14 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val conditionAsts = astsForExpression(expr.getCondition, None) val thenAsts = astsForExpression(expr.getThen, None) val elseAsts = Option(expr.getElse).toSeq.flatMap(astsForExpression(_, None)) val allAsts = (conditionAsts ++ thenAsts ++ elseAsts).toList if (allAsts.nonEmpty) { - val returnTypeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + val returnTypeFullName = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val node = NodeBuilders.newOperatorCallNode( Operators.conditional, @@ -319,9 +313,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { } } - def astForWhile(expr: KtWhileExpression, annotations: Seq[KtAnnotationEntry] = Seq())(implicit - typeInfoProvider: TypeInfoProvider - ): Ast = { + def astForWhile(expr: KtWhileExpression, annotations: Seq[KtAnnotationEntry] = Seq()): Ast = { val conditionAst = astsForExpression(expr.getCondition, None).headOption val stmtAsts = astsForExpression(expr.getBody, None) val code = Option(expr.getText) @@ -332,9 +324,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { .withChildren(annotations.map(astForAnnotationEntry)) } - def astForDoWhile(expr: KtDoWhileExpression, annotations: Seq[KtAnnotationEntry] = Seq())(implicit - typeInfoProvider: TypeInfoProvider - ): Ast = { + def astForDoWhile(expr: KtDoWhileExpression, annotations: Seq[KtAnnotationEntry] = Seq()): Ast = { val conditionAst = astsForExpression(expr.getCondition, None).headOption val stmtAsts = astsForExpression(expr.getBody, None) val code = Option(expr.getText) @@ -345,9 +335,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { .withChildren(annotations.map(astForAnnotationEntry)) } - private def astForWhenAsStatement(expr: KtWhenExpression, argIdx: Option[Int])(implicit - typeInfoProvider: TypeInfoProvider - ): Ast = { + private def astForWhenAsStatement(expr: KtWhenExpression, argIdx: Option[Int]): Ast = { val (astForSubject, finalAstForSubject) = Option(expr.getSubjectExpression) match { case Some(subjectExpression) => val astForSubject = astsForExpression(subjectExpression, Some(1)).headOption.getOrElse(Ast()) @@ -369,12 +357,12 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { }.flatten val switchBlockNode = - blockNode(expr, expr.getEntries.asScala.map(_.getText).mkString("\n"), TypeConstants.any) + blockNode(expr, expr.getEntries.asScala.map(_.getText).mkString("\n"), TypeConstants.Any) val astForBlock = blockAst(switchBlockNode, astsForEntries.toList) val codeForSwitch = Option(expr.getSubjectExpression) .map(_.getText) - .map { text => s"${Constants.when}($text)" } - .getOrElse(Constants.when) + .map { text => s"${Constants.WhenKeyword}($text)" } + .getOrElse(Constants.WhenKeyword) val switchNode = controlStructureNode(expr, ControlStructureTypes.SWITCH, codeForSwitch) val ast = Ast(withArgumentIndex(switchNode, argIdx)).withChildren(List(astForSubject, astForBlock)) // TODO: rewrite this as well @@ -384,10 +372,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { } } - def astForWhenAsExpression(expr: KtWhenExpression, argIdx: Option[Int], argNameMaybe: Option[String])(implicit - typeInfoProvider: TypeInfoProvider - ): Ast = { - + def astForWhenAsExpression(expr: KtWhenExpression, argIdx: Option[Int], argNameMaybe: Option[String]): Ast = { val callNode = withArgumentIndex(NodeBuilders.newOperatorCallNode(".when", ".when", None), argIdx) .argumentName(argNameMaybe) @@ -416,10 +401,10 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { callAst(callNode, List(subjectBlockAst) ++ argAsts) } - private def astForNoArgWhen(expr: KtWhenExpression)(implicit typeInfoProvider: TypeInfoProvider): Ast = { + private def astForNoArgWhen(expr: KtWhenExpression): Ast = { assert(expr.getSubjectExpression == null) - val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + val typeFullName = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) var elseAst: Ast = Ast() // Initialize this as `Ast()` instead of `null`, as there is no guarantee of else block // In reverse order than expr.getEntries since that is the order @@ -449,7 +434,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { logger.debug( s"Creating empty AST node for unknown condition expression `${cond.getClass}` with text `${cond.getText}`." ) - Seq(Ast(unknownNode(expr, Option(expr).map(_.getText).getOrElse(Constants.codePropUndefinedValue)))) + Seq(Ast(unknownNode(expr, Option(expr).map(_.getText).getOrElse(Constants.CodePropUndefinedValue)))) case None => // This is the 'else' branch of 'when'. // and thus first in reverse order, if exists @@ -464,7 +449,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val outAst = if (expr.getSubjectExpression != null) { typeInfoProvider.usedAsExpression(expr) match { @@ -477,20 +462,18 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { outAst.withChildren(annotations.map(astForAnnotationEntry)) } - private def astsForWhenEntry(entry: KtWhenEntry, argIdx: Int)(implicit - typeInfoProvider: TypeInfoProvider - ): Seq[Ast] = { + private def astsForWhenEntry(entry: KtWhenEntry, argIdx: Int): Seq[Ast] = { // TODO: get all conditions with entry.getConditions() val name = - if (entry.getElseKeyword == null) Constants.defaultCaseNode - else s"${Constants.caseNodePrefix}$argIdx" - val jumpNode = jumpTargetNode(entry, name, entry.getText, Some(Constants.caseNodeParserTypeName)) + if (entry.getElseKeyword == null) Constants.DefaultCaseNode + else s"${Constants.CaseNodePrefix}$argIdx" + val jumpNode = jumpTargetNode(entry, name, entry.getText, Some(Constants.CaseNodeParserTypeName)) .argumentIndex(argIdx) val exprNode = astsForExpression(entry.getExpression, Some(argIdx + 1)).headOption.getOrElse(Ast()) Seq(Ast(jumpNode), exprNode) } - private def astForTryAsStatement(expr: KtTryExpression)(implicit typeInfoProvider: TypeInfoProvider): Ast = { + private def astForTryAsStatement(expr: KtTryExpression): Ast = { val tryAst = astsForExpression(expr.getTryBlock, None).headOption.getOrElse(Ast()) val clauseAsts = expr.getCatchClauses.asScala.toSeq.map { catchClause => val catchNode = controlStructureNode(catchClause, ControlStructureTypes.CATCH, catchClause.getText) @@ -513,10 +496,10 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { val typeFullName = registerType( // TODO: remove the `last` - typeInfoProvider.expressionType(expr.getTryBlock.getStatements.asScala.last, TypeConstants.any) + exprTypeFullName(expr.getTryBlock.getStatements.asScala.last).getOrElse(TypeConstants.Any) ) val tryBlockAst = astsForExpression(expr.getTryBlock, None).headOption.getOrElse(Ast()) val clauseAsts = expr.getCatchClauses.asScala.toSeq.flatMap { entry => @@ -537,7 +520,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { argIdx: Option[Int], argNameMaybe: Option[String], annotations: Seq[KtAnnotationEntry] = Seq() - )(implicit typeInfoProvider: TypeInfoProvider): Ast = { + ): Ast = { if (KtPsiUtil.isStatement(expr)) astForTryAsStatement(expr) else astForTryAsExpression(expr, argIdx, argNameMaybe, annotations) } @@ -560,8 +543,8 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { localsForCaptures: List[NewLocal] = List(), implicitReturnAroundLastStatement: Boolean = false, preStatements: Option[Seq[Ast]] = None - )(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { - val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) + ): Seq[Ast] = { + val typeFullName = registerType(exprTypeFullName(expr).getOrElse(TypeConstants.Any)) val node = withArgumentIndex( blockNode(expr, expr.getStatements.asScala.map(_.getText).mkString("\n"), typeFullName), @@ -582,7 +565,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { val lastStatementAstWithTail = if (implicitReturnAroundLastStatement && statements.nonEmpty) { - val _returnNode = returnNode(statements.last, Constants.retCode) + val _returnNode = returnNode(statements.last, Constants.RetCode) val astsForLastStatement = astsForExpression(statements.last, Some(1)) if (astsForLastStatement.isEmpty) (Seq(), None) diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/BindingContextUtils.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/BindingContextUtils.scala new file mode 100644 index 000000000000..a84460bf68b0 --- /dev/null +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/BindingContextUtils.scala @@ -0,0 +1,109 @@ +package io.joern.kotlin2cpg.ast + +import org.jetbrains.kotlin.descriptors.annotations.AnnotationDescriptor +import org.jetbrains.kotlin.descriptors.{ + ClassDescriptor, + ConstructorDescriptor, + DeclarationDescriptor, + FunctionDescriptor, + TypeAliasDescriptor, + VariableDescriptor +} +import org.jetbrains.kotlin.psi.{ + KtAnnotationEntry, + KtClassOrObject, + KtConstructor, + KtDestructuringDeclarationEntry, + KtExpression, + KtFunctionLiteral, + KtNamedFunction, + KtParameter, + KtProperty, + KtReferenceExpression, + KtTypeAlias, + KtTypeReference +} +import org.jetbrains.kotlin.resolve.BindingContext +import org.jetbrains.kotlin.resolve.calls.model.ResolvedCall +import org.jetbrains.kotlin.types.KotlinType +import org.jetbrains.kotlin.types.error.ErrorType + +import scala.jdk.CollectionConverters.* + +class BindingContextUtils(val bindingContext: BindingContext) { + + def getClassDesc(classAst: KtClassOrObject): ClassDescriptor = { + bindingContext.get(BindingContext.CLASS, classAst) + } + + def getFunctionDesc(functionAst: KtNamedFunction): FunctionDescriptor = { + bindingContext.get(BindingContext.FUNCTION, functionAst) + } + + def getFunctionDesc(functionLiteralAst: KtFunctionLiteral): FunctionDescriptor = { + bindingContext.get(BindingContext.FUNCTION, functionLiteralAst) + } + + def getConstructorDesc(constructorAst: KtConstructor[?]): ConstructorDescriptor = { + bindingContext.get(BindingContext.CONSTRUCTOR, constructorAst) + } + + def getCalledFunctionDesc(destructuringAst: KtDestructuringDeclarationEntry): Option[FunctionDescriptor] = { + val resolvedCall = Option(bindingContext.get(BindingContext.COMPONENT_RESOLVED_CALL, destructuringAst)) + resolvedCall.map(_.getResultingDescriptor) + } + + def getCalledFunctionDesc(expressionAst: KtExpression): Option[FunctionDescriptor] = { + val call = Option(bindingContext.get(BindingContext.CALL, expressionAst)) + val resolvedCall = call.flatMap(call => Option(bindingContext.get(BindingContext.RESOLVED_CALL, call))) + resolvedCall.map(_.getResultingDescriptor).collect { case functionDesc: FunctionDescriptor => functionDesc } + } + + def getAmbiguousCalledFunctionDescs(expression: KtExpression): collection.Seq[FunctionDescriptor] = { + val descriptors = bindingContext.get(BindingContext.AMBIGUOUS_REFERENCE_TARGET, expression) + if (descriptors == null) { return Seq.empty } + descriptors.asScala.toSeq.collect { case functionDescriptor: FunctionDescriptor => functionDescriptor } + } + + def getResolvedCallDesc(expr: KtExpression): Option[ResolvedCall[?]] = { + val call = Option(bindingContext.get(BindingContext.CALL, expr)) + val resolvedCall = call.flatMap(call => Option(bindingContext.get(BindingContext.RESOLVED_CALL, call))) + resolvedCall + } + + def getVariableDesc(param: KtParameter): Option[VariableDescriptor] = { + Option(bindingContext.get(BindingContext.VALUE_PARAMETER, param)) + } + + def getVariableDesc(entry: KtDestructuringDeclarationEntry): Option[VariableDescriptor] = { + Option(bindingContext.get(BindingContext.VARIABLE, entry)) + } + + def getVariableDesc(property: KtProperty): Option[VariableDescriptor] = { + Option(bindingContext.get(BindingContext.VARIABLE, property)) + } + + def getTypeAliasDesc(typeAlias: KtTypeAlias): TypeAliasDescriptor = { + bindingContext.get(BindingContext.TYPE_ALIAS, typeAlias) + } + + def getAnnotationDesc(entry: KtAnnotationEntry): AnnotationDescriptor = { + bindingContext.get(BindingContext.ANNOTATION, entry) + } + + def getDeclDesc(nameRefExpr: KtReferenceExpression): Option[DeclarationDescriptor] = { + Option(bindingContext.get(BindingContext.REFERENCE_TARGET, nameRefExpr)) + } + + def getExprType(expr: KtExpression): Option[KotlinType] = { + Option(bindingContext.get(BindingContext.EXPRESSION_TYPE_INFO, expr)) + .flatMap(typeInfo => Option(typeInfo.getType)) + } + + def getTypeRefType(typeRef: KtTypeReference): Option[KotlinType] = { + Option(bindingContext.get(BindingContext.TYPE, typeRef)) match { + case Some(_: ErrorType) => None + case other => other + } + } +} diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/compiler/CompilerAPI.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/compiler/CompilerAPI.scala index 26fa86253dbf..537a274da52f 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/compiler/CompilerAPI.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/compiler/CompilerAPI.scala @@ -45,6 +45,9 @@ object CompilerAPI { logger.warn(s"Path to dependency does not point to existing file `${path.path}`.") } } else { + // We have to copy the resource file to a proper file in the file system in order + // to satisfy the requirements of `JvmClassPathRoot` which expects a proper `java.io.File` + // which in turn cannot represent files in resources. val resourceStream = getClass.getClassLoader.getResourceAsStream(path.path) if (resourceStream != null) { val tempFile = File.createTempFile("kotlin2cpgDependencies", "") diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/files/SourceFilesPicker.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/files/SourceFilesPicker.scala index b997c92103fa..1ff9e13e66b4 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/files/SourceFilesPicker.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/files/SourceFilesPicker.scala @@ -7,7 +7,7 @@ import org.slf4j.LoggerFactory object SourceFilesPicker { private val logger = LoggerFactory.getLogger(getClass) - private val substringsToFilterFor = List( + private val SubstringsToFilterFor = List( ".idea", "target", "build", @@ -22,7 +22,7 @@ object SourceFilesPicker { ) def shouldFilter(fileName: String): Boolean = { - val containsUnwantedSubstring = substringsToFilterFor.exists(fileName.contains) + val containsUnwantedSubstring = SubstringsToFilterFor.exists(fileName.contains) val isAndroidLayoutXml = fileName.endsWith("xml") && (fileName.contains("drawable") || fileName.contains("layout")) val containsSrcTest = fileName.contains("src/test") diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/interop/JavasrcInterop.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/interop/JavaSrcInterop.scala similarity index 68% rename from joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/interop/JavasrcInterop.scala rename to joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/interop/JavaSrcInterop.scala index 1f67fb891864..77bbb45e1f85 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/interop/JavasrcInterop.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/interop/JavaSrcInterop.scala @@ -4,9 +4,9 @@ import io.joern.javasrc2cpg.passes.{AstCreationPass => JavaSrcAstCreationPass} import io.joern.javasrc2cpg.JavaSrc2Cpg import io.shiftleft.codepropertygraph.generated.Cpg -object JavasrcInterop { +object JavaSrcInterop { def astCreationPass(inputPath: String, paths: List[String], cpg: Cpg): JavaSrcAstCreationPass = { - val javasrcConfig = JavaSrc2Cpg.DefaultConfig.withInputPath(inputPath) - new JavaSrcAstCreationPass(javasrcConfig, cpg, Some(paths)) + val javaSrcConfig = JavaSrc2Cpg.DefaultConfig.withInputPath(inputPath) + new JavaSrcAstCreationPass(javaSrcConfig, cpg, Some(paths)) } } diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/jar4import/Service.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/jar4import/Service.scala index b01a4e86ecd5..e3ca828a56a2 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/jar4import/Service.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/jar4import/Service.scala @@ -5,11 +5,9 @@ import org.slf4j.LoggerFactory class Service(url: String) { private val logger = LoggerFactory.getLogger(getClass) - private val findUrl: String = url + "/find" - def fetchDependencyCoordinates(imports: Seq[String]): Seq[String] = { try { - val resp = requests.get(findUrl, params = Map("names" -> imports.mkString(","))) + val resp = requests.get(url + "/find", params = Map("names" -> imports.mkString(","))) if (resp.statusCode == 200) { val got = ujson.read(resp.bytes) got("matches") match { diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/jar4import/UsesService.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/jar4import/UsesService.scala index 69c7ce4d5bba..7ede3b104970 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/jar4import/UsesService.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/jar4import/UsesService.scala @@ -4,7 +4,7 @@ import com.squareup.tools.maven.resolution.ArtifactResolver import io.joern.kotlin2cpg.Kotlin2Cpg import org.slf4j.LoggerFactory -import java.net.{MalformedURLException, URL} +import java.net.{MalformedURLException, URI} trait UsesService { this: Kotlin2Cpg => @@ -12,7 +12,7 @@ trait UsesService { this: Kotlin2Cpg => protected def reachableServiceMaybe(serviceUrl: String): Option[Service] = { try { - val url = new URL(serviceUrl) + val url = new URI(serviceUrl).toURL val healthResponse = requests.get(url.toString + "/health") if (healthResponse.statusCode != 200) { println(s"The jar4import service at `${url.toString}` did not respond with 200 on the `/health` endpoint.") diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/AstCreationPass.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/AstCreationPass.scala index cd9f367ff0a1..dc3772204ae0 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/AstCreationPass.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/AstCreationPass.scala @@ -2,16 +2,16 @@ package io.joern.kotlin2cpg.passes import io.joern.kotlin2cpg.KtFileWithMeta import io.joern.kotlin2cpg.ast.AstCreator -import io.joern.kotlin2cpg.types.TypeInfoProvider import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.datastructures.Global import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.passes.ForkJoinParallelCpgPass +import org.jetbrains.kotlin.resolve.BindingContext import org.slf4j.LoggerFactory import scala.jdk.CollectionConverters.EnumerationHasAsScala -class AstCreationPass(filesWithMeta: Iterable[KtFileWithMeta], typeInfoProvider: TypeInfoProvider, cpg: Cpg)(implicit +class AstCreationPass(filesWithMeta: Iterable[KtFileWithMeta], bindingContext: BindingContext, cpg: Cpg)(implicit withSchemaValidation: ValidationMode ) extends ForkJoinParallelCpgPass[KtFileWithMeta](cpg) { @@ -23,7 +23,7 @@ class AstCreationPass(filesWithMeta: Iterable[KtFileWithMeta], typeInfoProvider: override def generateParts(): Array[KtFileWithMeta] = filesWithMeta.toArray override def runOnPart(diffGraph: DiffGraphBuilder, fileWithMeta: KtFileWithMeta): Unit = { - diffGraph.absorb(new AstCreator(fileWithMeta, typeInfoProvider, global).createAst()) + diffGraph.absorb(new AstCreator(fileWithMeta, bindingContext, global).createAst()) logger.debug(s"AST created for file at `${fileWithMeta.f.getVirtualFilePath}`.") } diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/DependenciesFromMavenCoordinatesPass.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/DependenciesFromMavenCoordinatesPass.scala index 18a210fa1ffa..a3ac5a4dd464 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/DependenciesFromMavenCoordinatesPass.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/DependenciesFromMavenCoordinatesPass.scala @@ -23,11 +23,11 @@ org.springframework:spring-context:6.0.7 ``` */ class DependenciesFromMavenCoordinatesPass(coordinates: Seq[String], cpg: Cpg) extends CpgPass(cpg) { - private val keyValPattern: Regex = "^([^:]+):([^:]+):([^:]+)$".r + private val KeyValPattern: Regex = "^([^:]+):([^:]+):([^:]+)$".r override def run(dstGraph: DiffGraphBuilder): Unit = { coordinates.foreach { coordinate => - for (patternMatch <- keyValPattern.findAllMatchIn(coordinate)) { + for (patternMatch <- KeyValPattern.findAllMatchIn(coordinate)) { val groupId = patternMatch.group(1) val name = patternMatch.group(2) val version = patternMatch.group(3) diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeHintCallLinker.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeHintCallLinker.scala index 5c688e3bbc79..31ed0b375ce5 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeHintCallLinker.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeHintCallLinker.scala @@ -4,7 +4,7 @@ import io.joern.x2cpg.Defines import io.joern.x2cpg.passes.frontend.XTypeHintCallLinker import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Call -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.util.regex.Pattern diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeRecoveryPassGenerator.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeRecoveryPassGenerator.scala index 9b0960fcc2e3..3879f39a124b 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeRecoveryPassGenerator.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeRecoveryPassGenerator.scala @@ -7,7 +7,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.PropertyNames import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder class KotlinTypeRecoveryPassGenerator(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) extends XTypeRecoveryPassGenerator[File](cpg, config) { @@ -41,7 +41,7 @@ private class RecoverForKotlinFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder val alias = i.importedAs.getOrElse("") val fullName = i.importedEntity.getOrElse("") - if (alias != Constants.wildcardImportName) { + if (alias != Constants.WildcardImportName) { symbolTable.append(CallAlias(alias), fullName) symbolTable.append(LocalVar(alias), fullName) } diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/ContentSourcesPicker.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/ContentSourcesPicker.scala index 45e3bc225e60..15631691c37e 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/ContentSourcesPicker.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/ContentSourcesPicker.scala @@ -1,7 +1,6 @@ package io.joern.kotlin2cpg.types -import better.files.{File => BFile} -import io.joern.kotlin2cpg.DefaultContentRootJarPath +import better.files.File object ContentSourcesPicker { @@ -19,7 +18,7 @@ object ContentSourcesPicker { // `Seq("dir1/dir2/dir3")` and nothing else. def dirsForRoot(rootDir: String): Seq[String] = { - val dir = BFile(rootDir) + val dir = File(rootDir) val hasSubDirs = dir.list.exists(_.isDirectory) if (!hasSubDirs) { return Seq(rootDir) diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/DefaultTypeInfoProvider.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/DefaultTypeInfoProvider.scala deleted file mode 100644 index 523b5cb0c9e1..000000000000 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/DefaultTypeInfoProvider.scala +++ /dev/null @@ -1,1015 +0,0 @@ -package io.joern.kotlin2cpg.types - -import io.joern.kotlin2cpg.psi.PsiUtils -import io.joern.x2cpg.Defines -import io.shiftleft.codepropertygraph.generated.Operators -import kotlin.reflect.jvm.internal.impl.load.java.descriptors.JavaClassConstructorDescriptor -import org.jetbrains.kotlin.cli.jvm.compiler.KotlinCoreEnvironment -import org.jetbrains.kotlin.cli.jvm.compiler.KotlinToJVMBytecodeCompiler -import org.jetbrains.kotlin.cli.jvm.compiler.NoScopeRecordCliBindingTrace -import org.jetbrains.kotlin.com.intellij.util.keyFMap.KeyFMap -import org.jetbrains.kotlin.descriptors.DeclarationDescriptor -import org.jetbrains.kotlin.descriptors.DescriptorVisibility -import org.jetbrains.kotlin.descriptors.FunctionDescriptor -import org.jetbrains.kotlin.descriptors.ValueDescriptor -import org.jetbrains.kotlin.descriptors.ValueParameterDescriptor -import org.jetbrains.kotlin.descriptors.impl.ClassConstructorDescriptorImpl -import org.jetbrains.kotlin.descriptors.impl.EnumEntrySyntheticClassDescriptor -import org.jetbrains.kotlin.descriptors.impl.LazyPackageViewDescriptorImpl -import org.jetbrains.kotlin.descriptors.impl.PropertyDescriptorImpl -import org.jetbrains.kotlin.descriptors.impl.TypeAliasConstructorDescriptorImpl -import org.jetbrains.kotlin.descriptors.CallableDescriptor -import org.jetbrains.kotlin.descriptors.Modality -import org.jetbrains.kotlin.load.java.`lazy`.descriptors.LazyJavaClassDescriptor -import org.jetbrains.kotlin.load.java.sources.JavaSourceElement -import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaMethod -import org.jetbrains.kotlin.psi.KtAnnotationEntry -import org.jetbrains.kotlin.psi.KtArrayAccessExpression -import org.jetbrains.kotlin.psi.KtBinaryExpression -import org.jetbrains.kotlin.psi.KtCallExpression -import org.jetbrains.kotlin.psi.KtClassBody -import org.jetbrains.kotlin.psi.KtClassLiteralExpression -import org.jetbrains.kotlin.psi.KtClassOrObject -import org.jetbrains.kotlin.psi.KtDestructuringDeclarationEntry -import org.jetbrains.kotlin.psi.KtElement -import org.jetbrains.kotlin.psi.KtExpression -import org.jetbrains.kotlin.psi.KtFile -import org.jetbrains.kotlin.psi.KtLambdaExpression -import org.jetbrains.kotlin.psi.KtNameReferenceExpression -import org.jetbrains.kotlin.psi.KtNamedFunction -import org.jetbrains.kotlin.psi.KtParameter -import org.jetbrains.kotlin.psi.KtPrimaryConstructor -import org.jetbrains.kotlin.psi.KtProperty -import org.jetbrains.kotlin.psi.KtPsiUtil -import org.jetbrains.kotlin.psi.KtQualifiedExpression -import org.jetbrains.kotlin.psi.KtSecondaryConstructor -import org.jetbrains.kotlin.psi.KtSuperExpression -import org.jetbrains.kotlin.psi.KtThisExpression -import org.jetbrains.kotlin.psi.KtTypeAlias -import org.jetbrains.kotlin.psi.KtTypeReference -import org.jetbrains.kotlin.resolve.BindingContext -import org.jetbrains.kotlin.resolve.DescriptorUtils -import org.jetbrains.kotlin.resolve.DescriptorUtils.getSuperclassDescriptors -import org.jetbrains.kotlin.resolve.`lazy`.descriptors.LazyClassDescriptor -import org.jetbrains.kotlin.serialization.deserialization.descriptors.DeserializedClassDescriptor -import org.jetbrains.kotlin.types.TypeUtils -import org.jetbrains.kotlin.types.error.ErrorType -import org.slf4j.LoggerFactory - -import scala.jdk.CollectionConverters.CollectionHasAsScala -import scala.util.Failure -import scala.util.Success -import scala.util.Try -import scala.util.control.NonFatal - -class DefaultTypeInfoProvider(environment: KotlinCoreEnvironment, typeRenderer: TypeRenderer = new TypeRenderer()) - extends TypeInfoProvider(typeRenderer) { - private val logger = LoggerFactory.getLogger(getClass) - - import DefaultTypeInfoProvider.* - - val bindingContext: BindingContext = { - Try { - logger.info("Running Kotlin compiler analysis...") - val t0 = System.currentTimeMillis() - val analysisResult = KotlinToJVMBytecodeCompiler.INSTANCE.analyze(environment) - val t1 = System.currentTimeMillis() - logger.info(s"Kotlin compiler analysis finished in `${t1 - t0}` ms.") - analysisResult - } match { - case Success(analysisResult) => analysisResult.getBindingContext - case Failure(exc) => - logger.error(s"Kotlin compiler analysis failed with exception `${exc.toString}`:`${exc.getMessage}`.", exc) - BindingContext.EMPTY - } - } - - private def isValidRender(render: String): Boolean = { - !render.contains("ERROR") - } - - def anySignature(args: Seq[Any]): String = { - val argsSignature = - if (args.isEmpty) "" - else if (args.size == 1) TypeConstants.any - else s"${TypeConstants.any}${s",${TypeConstants.any}" * (args.size - 1)}" - s"${TypeConstants.any}($argsSignature)" - } - - def usedAsExpression(expr: KtExpression): Option[Boolean] = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - Option(mapForEntity.get(BindingContext.USED_AS_EXPRESSION.getKey)).map(_.booleanValue()) - } - - def fullName(expr: KtTypeAlias, defaultValue: String): String = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - Option(mapForEntity.get(BindingContext.TYPE_ALIAS.getKey)) - .map(typeRenderer.renderFqNameForDesc) - .filter(isValidRender) - .getOrElse(defaultValue) - } - - def visibility(fn: KtNamedFunction): Option[DescriptorVisibility] = { - val mapForEntity = bindingsForEntity(bindingContext, fn) - Option(mapForEntity.get(BindingContext.FUNCTION.getKey)) - .map(_.getVisibility) - } - - def modality(fn: KtNamedFunction): Option[Modality] = { - val mapForEntity = bindingsForEntity(bindingContext, fn) - Option(mapForEntity.get(BindingContext.FUNCTION.getKey)) - .map(_.getModality) - } - - def modality(ktClass: KtClassOrObject): Option[Modality] = { - val mapForEntity = bindingsForEntity(bindingContext, ktClass) - Option(mapForEntity.get(BindingContext.CLASS.getKey)) - .map(_.getModality) - } - - def containingTypeDeclFullName(ktFn: KtNamedFunction, defaultValue: String): String = { - val mapForEntity = bindingsForEntity(bindingContext, ktFn) - Option(mapForEntity.get(BindingContext.FUNCTION.getKey)) - .map { fnDesc => - if (DescriptorUtils.isExtension(fnDesc)) - typeRenderer.render(fnDesc.getExtensionReceiverParameter.getType) - else - typeRenderer.renderFqNameForDesc(fnDesc.getContainingDeclaration) - } - .getOrElse(defaultValue) - } - - def isStaticMethodCall(expr: KtQualifiedExpression): Boolean = { - resolvedCallDescriptor(expr) - .map(_.getSource) - .collect { case s: JavaSourceElement => s } - .map(_.getJavaElement) - .collect { case bjm: BinaryJavaMethod => bjm } - .exists(_.isStatic) - } - - def fullNameWithSignature(expr: KtDestructuringDeclarationEntry, defaultValue: (String, String)): (String, String) = { - Option(bindingContext.get(BindingContext.COMPONENT_RESOLVED_CALL, expr)) - .map { resolvedCall => - val fnDesc = resolvedCall.getResultingDescriptor - val relevantDesc = - if (!fnDesc.isActual && fnDesc.getOverriddenDescriptors.asScala.nonEmpty) - fnDesc.getOverriddenDescriptors.asScala.toList.head - else fnDesc - val renderedFqName = typeRenderer.renderFqNameForDesc(relevantDesc) - val returnTypeFullName = renderedReturnType(relevantDesc.getOriginal) - - val renderedParameterTypes = - relevantDesc.getValueParameters.asScala.toSeq - .map(renderTypeForParameterDesc) - .mkString(",") - val signature = s"$returnTypeFullName($renderedParameterTypes)" - val fullName = s"$renderedFqName:$signature" - - if (!isValidRender(fullName) || !isValidRender(signature)) defaultValue - else (fullName, signature) - } - .getOrElse(defaultValue) - } - - def isRefToCompanionObject(expr: KtNameReferenceExpression): Boolean = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - Option(mapForEntity) - .map(_.getKeys) - .exists(_.contains(BindingContext.SHORT_REFERENCE_TO_COMPANION_OBJECT.getKey)) - } - - def typeFullName(expr: KtDestructuringDeclarationEntry, defaultValue: String): String = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - Option(mapForEntity.get(BindingContext.VARIABLE.getKey)) - .map { desc => typeRenderer.render(desc.getType) } - .filter(isValidRender) - .getOrElse(defaultValue) - } - - def typeFullName(expr: KtTypeReference, defaultValue: String): String = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - Option(mapForEntity.get(BindingContext.TYPE.getKey)) - .map(typeRenderer.render(_)) - .filter(isValidRender) - .getOrElse(defaultValue) - } - - def typeFullName(expr: KtCallExpression, defaultValue: String): String = { - resolvedCallDescriptor(expr) - .map(_.getOriginal) - .map { originalDesc => - val relevantDesc = - if (!originalDesc.isActual && originalDesc.getOverriddenDescriptors.asScala.nonEmpty) - originalDesc.getOverriddenDescriptors.asScala.toList.head - else originalDesc - if (isConstructorCall(expr).getOrElse(false)) TypeConstants.void - else renderedReturnType(relevantDesc.getOriginal) - } - .getOrElse(defaultValue) - } - - def aliasTypeFullName(expr: KtTypeAlias, defaultValue: String): String = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - Option(mapForEntity.get(BindingContext.TYPE_ALIAS.getKey)) - .map(_.getExpandedType) - .filterNot(_.isInstanceOf[ErrorType]) - .map(typeRenderer.render(_)) - .filter(isValidRender) - .getOrElse(defaultValue) - } - - def returnType(expr: KtNamedFunction, defaultValue: String): String = { - Option(bindingContext.get(BindingContext.FUNCTION, expr)) - .map(_.getReturnType) - .map(typeRenderer.render(_)) - .filter(isValidRender) - .getOrElse(defaultValue) - } - - def propertyType(expr: KtProperty, defaultValue: String): String = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - Option(mapForEntity.get(BindingContext.VARIABLE.getKey)) - .map(_.getType) - .filterNot(_.isInstanceOf[ErrorType]) - .map(typeRenderer.render(_)) - .filter(isValidRender) - .getOrElse( - Option(expr.getTypeReference) - .map { typeRef => - typeFromImports(typeRef.getText, expr.getContainingKtFile).getOrElse(typeRef.getText) - } - .getOrElse(defaultValue) - ) - } - - def typeFullName(expr: KtClassOrObject, defaultValue: String): String = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - Option(mapForEntity.get(BindingContext.CLASS.getKey)) - .map(_.getDefaultType) - .map(typeRenderer.render(_)) - .getOrElse(defaultValue) - } - - def inheritanceTypes(expr: KtClassOrObject, defaultValue: Seq[String]): Seq[String] = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - Option(mapForEntity.get(BindingContext.CLASS.getKey)) - .map(getSuperclassDescriptors) - .filter(_.asScala.nonEmpty) - .map( - _.asScala - .map { superClassDesc => - typeRenderer.render(superClassDesc.getDefaultType) - } - .toList - ) - .getOrElse(defaultValue) - } - - private def anonymousObjectIdx(obj: KtElement): Option[Int] = { - val parentFn = KtPsiUtil.getTopmostParentOfTypes(obj, classOf[KtNamedFunction]) - val containingObj = Option(parentFn).getOrElse(obj.getContainingKtFile) - PsiUtils.objectIdxMaybe(obj, containingObj) - } - - def fullName( - expr: KtClassOrObject, - defaultValue: String, - anonymousCtxMaybe: Option[AnonymousObjectContext] = None - ): String = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - val nonLocalFullName = Option(mapForEntity.get(BindingContext.CLASS.getKey)) - .map(_.getDefaultType) - .map(typeRenderer.render(_)) - .filter(isValidRender) - .getOrElse(defaultValue) - - if (anonymousCtxMaybe.nonEmpty) { - anonymousCtxMaybe - .map { _ => - val fnDescMaybe = Option(mapForEntity.get(BindingContext.CLASS.getKey)) - fnDescMaybe - .map(_.getContainingDeclaration) - .map { containingDecl => - val idxMaybe = anonymousObjectIdx(expr) - val idx = idxMaybe.map(_.toString).getOrElse("nan") - s"${typeRenderer.renderFqNameForDesc(containingDecl.getOriginal).stripSuffix(".")}" + "$object$" + s"$idx" - } - .getOrElse(nonLocalFullName) - } - .getOrElse(nonLocalFullName) - } else if (expr.isLocal) { - val fnDescMaybe = Option(mapForEntity.get(BindingContext.CLASS.getKey)) - fnDescMaybe - .map(_.getContainingDeclaration) - .map { containingDecl => - s"${typeRenderer.renderFqNameForDesc(containingDecl.getOriginal)}.${expr.getName}" - } - .getOrElse(nonLocalFullName) - } else nonLocalFullName - } - - def isCompanionObject(expr: KtClassOrObject): Boolean = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - Option(mapForEntity.get(BindingContext.CLASS.getKey)).exists(DescriptorUtils.isCompanionObject(_)) - } - - def typeFullName(expr: KtParameter, defaultValue: String): String = { - val mapForEntity = bindingsForEntity(bindingContext, expr) - Option(mapForEntity.get(BindingContext.VALUE_PARAMETER.getKey)) - .map(_.getType) - .map(typeRenderer.render(_)) - .filter(isValidRender) - .getOrElse(defaultValue) - } - - def expressionType(expr: KtExpression, defaultValue: String): String = { - Option(bindingContext.get(BindingContext.EXPRESSION_TYPE_INFO, expr)) - .flatMap(tpeInfo => Option(tpeInfo.getType)) - .map(typeRenderer.render(_)) - .filter(isValidRender) - .getOrElse(defaultValue) - } - - def fullNameWithSignature(expr: KtClassLiteralExpression, defaultValue: (String, String)): (String, String) = { - Option(bindingContext.get(BindingContext.EXPRESSION_TYPE_INFO, expr)) - .map(_.getType) - .map(_.getArguments.asScala) - .filter(_.nonEmpty) - .map { typeArguments => - val firstTypeArg = typeArguments.toList.head - val rendered = typeRenderer.render(firstTypeArg.getType) - val retType = expressionType(expr, TypeConstants.any) - val signature = s"$retType()" - val fullName = s"$rendered.${TypeConstants.classLiteralReplacementMethodName}:$signature" - (fullName, signature) - } - .getOrElse(defaultValue) - } - - private def subexpressionForResolvedCallInfo(expr: KtExpression): KtExpression = { - expr match { - case typedExpr: KtCallExpression => - Option(typedExpr.getFirstChild) - .collect { case expr: KtExpression => expr } - .getOrElse(expr) - case typedExpr: KtQualifiedExpression => - Option(typedExpr.getSelectorExpression) - .collect { case expr: KtCallExpression => expr } - .map(subexpressionForResolvedCallInfo) - .getOrElse(typedExpr) - case typedExpr: KtBinaryExpression => - Option(typedExpr.getChildren.toList(1)) - .collect { case expr: KtExpression => expr } - .getOrElse(expr) - case _ => expr - } - } - - private def resolvedCallDescriptor(expr: KtExpression): Option[FunctionDescriptor] = { - val relevantSubexpression = subexpressionForResolvedCallInfo(expr) - val descMaybe = for { - callForSubexpression <- Option(bindingContext.get(BindingContext.CALL, relevantSubexpression)) - resolvedCallForSubexpression <- Option(bindingContext.get(BindingContext.RESOLVED_CALL, callForSubexpression)) - desc = resolvedCallForSubexpression.getResultingDescriptor - } yield desc - descMaybe.collect { case desc: FunctionDescriptor => desc } - } - - private def isConstructorDescriptor(desc: FunctionDescriptor): Boolean = { - desc match { - case _: JavaClassConstructorDescriptor => true - case _: ClassConstructorDescriptorImpl => true - case _: TypeAliasConstructorDescriptorImpl => true - case _ => false - } - } - - def isConstructorCall(expr: KtExpression): Option[Boolean] = { - expr match { - case _: KtCallExpression | _: KtQualifiedExpression => - resolvedCallDescriptor(expr) match { - case Some(desc) if isConstructorDescriptor(desc) => Some(true) - case _ => Some(false) - } - case _ => Some(false) - } - } - - def fullNameWithSignature(expr: KtCallExpression, defaultValue: (String, String)): (String, String) = { - resolvedCallDescriptor(expr) match { - case Some(desc) => - val originalDesc = desc.getOriginal - val relevantDesc = originalDesc match { - case typedDesc: TypeAliasConstructorDescriptorImpl => - typedDesc.getUnderlyingConstructorDescriptor - case typedDesc: FunctionDescriptor if !typedDesc.isActual => - val overriddenDescriptors = typedDesc.getOverriddenDescriptors.asScala.toList - if (overriddenDescriptors.nonEmpty) overriddenDescriptors.head - else typedDesc - case _ => originalDesc - } - val returnTypeFullName = - if (isConstructorCall(expr).getOrElse(false)) TypeConstants.void - else renderedReturnType(relevantDesc.getOriginal) - val renderedParameterTypes = - relevantDesc.getValueParameters.asScala.toSeq - .map(renderTypeForParameterDesc) - .mkString(",") - val signature = s"$returnTypeFullName($renderedParameterTypes)" - - val renderedFqName = typeRenderer.renderFqNameForDesc(relevantDesc) - val fullName = - if (isConstructorCall(expr).getOrElse(false)) s"$renderedFqName${TypeConstants.initPrefix}:$signature" - else s"$renderedFqName:$signature" - if (!isValidRender(fullName) || !isValidRender(signature)) defaultValue - else (fullName, signature) - case None => - val relevantSubexpression = subexpressionForResolvedCallInfo(expr) - val numArgs = expr.getValueArguments.size - val ambiguousReferences = - Option(bindingContext.get(BindingContext.AMBIGUOUS_REFERENCE_TARGET, relevantSubexpression)) - .map(_.toArray.toSeq.collect { case desc: FunctionDescriptor => desc }) - .getOrElse(Seq()) - val chosenAmbiguousReference = ambiguousReferences.find(_.getValueParameters.size == numArgs) - chosenAmbiguousReference - .map { desc => - val signature = Defines.UnresolvedSignature - val fullName = s"${typeRenderer.renderFqNameForDesc(desc)}:$signature($numArgs)" - (fullName, signature) - } - .getOrElse(defaultValue) - } - } - - def typeFullName(expr: KtBinaryExpression, defaultValue: String): String = { - resolvedCallDescriptor(expr) - .map(_.getOriginal) - .map { desc => typeRenderer.render(desc.getReturnType) } - .getOrElse(defaultValue) - } - - def typeFullName(expr: KtAnnotationEntry, defaultValue: String): String = { - Option(bindingsForEntity(bindingContext, expr)) - .flatMap(_ => Option(bindingContext.get(BindingContext.ANNOTATION, expr))) - .map { desc => typeRenderer.render(desc.getType) } - .getOrElse(defaultValue) - } - - def fullNameWithSignature(expr: KtBinaryExpression, defaultValue: (String, String)): (String, String) = { - resolvedCallDescriptor(expr) - .map { fnDescriptor => - val originalDesc = fnDescriptor.getOriginal - val renderedParameterTypes = - originalDesc.getValueParameters.asScala.toSeq - .map(_.getType) - .map { t => typeRenderer.render(t) } - .mkString(",") - val renderedReturnType = typeRenderer.render(originalDesc.getReturnType) - val signature = s"$renderedReturnType($renderedParameterTypes)" - val fullName = - if (originalDesc.isInstanceOf[ClassConstructorDescriptorImpl]) { - s"$renderedReturnType.${TypeConstants.initPrefix}:$signature" - } else { - val renderedFqName = typeRenderer.renderFqNameForDesc(originalDesc) - s"$renderedFqName:$signature" - } - if (!isValidRender(fullName) || !isValidRender(signature)) defaultValue - else (fullName, signature) - } - .getOrElse(defaultValue) - } - - def containingDeclFullName(expr: KtCallExpression): Option[String] = { - resolvedCallDescriptor(expr) - .map(_.getContainingDeclaration) - .map(typeRenderer.renderFqNameForDesc) - } - - def containingDeclType(expr: KtQualifiedExpression, defaultValue: String): String = { - resolvedCallDescriptor(expr) - .map(_.getContainingDeclaration) - .map(typeRenderer.renderFqNameForDesc) - .getOrElse(defaultValue) - } - - def hasStaticDesc(expr: KtQualifiedExpression): Boolean = { - resolvedCallDescriptor(expr).forall(_.getDispatchReceiverParameter == null) - } - - def bindingKind(expr: KtQualifiedExpression): CallKind = { - val isStaticBasedOnStructure = expr.getReceiverExpression.isInstanceOf[KtSuperExpression] - if (isStaticBasedOnStructure) return CallKind.StaticCall - - val isDynamicBasedOnStructure = expr.getReceiverExpression match { - case _: KtArrayAccessExpression => true - case _: KtThisExpression => true - case _ => false - } - if (isDynamicBasedOnStructure) return CallKind.DynamicCall - - resolvedCallDescriptor(expr) - .map { desc => - val isExtension = DescriptorUtils.isExtension(desc) - val isStatic = DescriptorUtils.isStaticDeclaration(desc) || hasStaticDesc(expr) - - if (isExtension) CallKind.ExtensionCall - else if (isStatic) CallKind.StaticCall - else CallKind.DynamicCall - } - .getOrElse(CallKind.Unknown) - } - - def isExtensionFn(fn: KtNamedFunction): Boolean = { - Option(bindingContext.get(BindingContext.FUNCTION, fn)).exists(DescriptorUtils.isExtension) - } - - private def renderTypeForParameterDesc(p: ValueParameterDescriptor): String = { - val typeUpperBounds = - Option(TypeUtils.getTypeParameterDescriptorOrNull(p.getType)) - .map(_.getUpperBounds) - .map(_.asScala) - .map(_.toList) - .getOrElse(List()) - if (typeUpperBounds.nonEmpty) - typeRenderer.render(typeUpperBounds.head) - else - typeRenderer.render(p.getOriginal.getType) - } - - def fullNameWithSignature(expr: KtQualifiedExpression, defaultValue: (String, String)): (String, String) = { - resolvedCallDescriptor(expr) match { - case Some(fnDescriptor) => - val originalDesc = fnDescriptor.getOriginal - - val renderedFqNameForDesc = typeRenderer.renderFqNameForDesc(fnDescriptor) - val renderedFqNameMaybe = for { - extensionReceiverParam <- Option(originalDesc.getExtensionReceiverParameter) - erpType = extensionReceiverParam.getType - } yield { - val typeUpperBounds = - Option(TypeUtils.getTypeParameterDescriptorOrNull(erpType)) - .map(_.getUpperBounds) - .map(_.asScala) - .map(_.toList) - .getOrElse(List()) - if (erpType.isInstanceOf[ErrorType]) { - s"${Defines.UnresolvedNamespace}.${expr.getName}" - } else { - val rendered = - if (renderedFqNameForDesc.startsWith(TypeConstants.kotlinApplyPrefix)) TypeConstants.javaLangObject - else if (typeUpperBounds.size == 1) { - typeRenderer.render( - typeUpperBounds.head, - shouldMapPrimitiveArrayTypes = false, - unwrapPrimitives = false - ) - } else typeRenderer.render(erpType, shouldMapPrimitiveArrayTypes = false, unwrapPrimitives = false) - s"$rendered.${originalDesc.getName}" - } - } - - val renderedFqName = - Option(originalDesc.getDispatchReceiverParameter) - .map(_.getOriginal) - .map(_.getContainingDeclaration) - .map { objDesc => - if (DescriptorUtils.isAnonymousObject(objDesc)) { - s"${typeRenderer.renderFqNameForDesc(objDesc)}.${originalDesc.getName}" - } else renderedFqNameMaybe.getOrElse(renderedFqNameForDesc) - } - .getOrElse(renderedFqNameMaybe.getOrElse(renderedFqNameForDesc)) - - val renderedParameterTypes = - originalDesc.getValueParameters.asScala.toSeq - .map(renderTypeForParameterDesc) - .mkString(",") - val renderedReturnType = - if (isConstructorDescriptor(originalDesc)) TypeConstants.void - else if (renderedFqNameForDesc.startsWith(TypeConstants.kotlinApplyPrefix)) TypeConstants.javaLangObject - else typeRenderer.render(originalDesc.getReturnType) - - val singleLambdaArgExprMaybe = expr.getSelectorExpression match { - case c: KtCallExpression if c.getLambdaArguments.size() == 1 => - Some(c.getLambdaArguments.get(0).getLambdaExpression) - case _ => None - } - val fullNameSignature = s"$renderedReturnType($renderedParameterTypes)" - val signature = - singleLambdaArgExprMaybe - .map(lambdaInvocationSignature(_, renderedReturnType)) - .getOrElse(fullNameSignature) - (s"$renderedFqName:$fullNameSignature", signature) - case None => - resolvedCallDescriptor(expr.getReceiverExpression) match { - case Some(desc) => - desc match { - case _: ClassConstructorDescriptorImpl | _: TypeAliasConstructorDescriptorImpl => - expr.getSelectorExpression match { - case _: KtNameReferenceExpression => (Operators.fieldAccess, "") - case _ => defaultValue - } - case _ => - val originalDesc = desc.getOriginal - val lhsName = typeRenderer.render(originalDesc.getReturnType) - val name = expr.getSelectorExpression.getFirstChild.getText - val numArgs = expr.getSelectorExpression match { - case c: KtCallExpression => c.getValueArguments.size() - case _ => 0 - } - val signature = s"${Defines.UnresolvedSignature}($numArgs)" - val fullName = s"$lhsName.$name:$signature" - (fullName, signature) - } - case None => defaultValue - } - } - } - - private def lambdaInvocationSignature(expr: KtLambdaExpression, returnType: String): String = { - val hasImplicitParameter = implicitParameterName(expr) - val params = expr.getValueParameters - val paramsString = - if (hasImplicitParameter.nonEmpty) TypeConstants.javaLangObject - else if (params.isEmpty) "" - else if (params.size() == 1) TypeConstants.javaLangObject - else - s"${TypeConstants.javaLangObject}${("," + TypeConstants.javaLangObject) * (expr.getValueParameters.size() - 1)}" - s"$returnType($paramsString)" - } - - def parameterType(parameter: KtParameter, defaultValue: String): String = { - // TODO: add specific test for no binding info of parameter - // triggered by exception in https://github.com/agrosner/DBFlow - // TODO: ...also test cases for non-null binding info for other fns - val render = for { - mapForEntity <- Option(bindingsForEntity(bindingContext, parameter)) - variableDesc <- Option(mapForEntity.get(BindingContext.VALUE_PARAMETER.getKey)) - typeUpperBounds = - Option(TypeUtils.getTypeParameterDescriptorOrNull(variableDesc.getType)) - .map(_.getUpperBounds) - .map(_.asScala) - .map(_.toList) - .getOrElse(List()) - render = - if (typeUpperBounds.nonEmpty) - typeRenderer.render(typeUpperBounds.head) - else - typeRenderer.render(variableDesc.getType) - if isValidRender(render) && !variableDesc.getType.isInstanceOf[ErrorType] - } yield render - - render.getOrElse( - Option(parameter.getTypeReference) - .map { typeRef => - typeFromImports(typeRef.getText, parameter.getContainingKtFile).getOrElse(typeRef.getText) - } - .getOrElse(defaultValue) - ) - } - - def destructuringEntryType(expr: KtDestructuringDeclarationEntry, defaultValue: String): String = { - val render = for { - mapForEntity <- Option(bindingsForEntity(bindingContext, expr)) - variableDesc <- Option(mapForEntity.get(BindingContext.VARIABLE.getKey)) - render = typeRenderer.render(variableDesc.getType) - if isValidRender(render) && !variableDesc.getType.isInstanceOf[ErrorType] - } yield render - render.getOrElse(defaultValue) - } - - def hasApplyOrAlsoScopeFunctionParent(expr: KtLambdaExpression): Boolean = { - expr.getParent.getParent match { - case callExpr: KtCallExpression => - resolvedCallDescriptor(callExpr) match { - case Some(desc) => - val rendered = typeRenderer.renderFqNameForDesc(desc.getOriginal) - rendered.startsWith(TypeConstants.kotlinApplyPrefix) || rendered.startsWith(TypeConstants.kotlinAlsoPrefix) - case _ => false - } - case _ => false - } - } - - def returnTypeFullName(expr: KtLambdaExpression): String = { - TypeConstants.javaLangObject - } - - def fullNameWithSignature(expr: KtLambdaExpression, lambdaName: String): (String, String) = { - val containingFile = expr.getContainingKtFile - val fileName = containingFile.getName - val packageName = containingFile.getPackageFqName.toString - val astDerivedFullName = s"$packageName:.$lambdaName()" - val astDerivedSignature = anySignature(expr.getValueParameters.asScala.toList) - - val render = for { - mapForEntity <- Option(bindingsForEntity(bindingContext, expr)) - typeInfo <- Option(mapForEntity.get(BindingContext.EXPRESSION_TYPE_INFO.getKey)) - theType = typeInfo.getType - } yield { - val constructorDesc = theType.getConstructor.getDeclarationDescriptor - val constructorType = constructorDesc.getDefaultType - val args = constructorType.getArguments.asScala.drop(1) - - val renderedRetType = - args.lastOption - .map { t => typeRenderer.render(t.getType) } - .getOrElse(TypeConstants.javaLangObject) - val renderedArgs = - if (args.isEmpty) "" - else if (args.size == 1) TypeConstants.javaLangObject - else s"${TypeConstants.javaLangObject}${("," + TypeConstants.javaLangObject) * (args.size - 1)}" - val signature = s"$renderedRetType($renderedArgs)" - val fullName = s"$packageName..$lambdaName:$signature" - (fullName, signature) - } - render.getOrElse((astDerivedFullName, astDerivedSignature)) - } - - private def renderedReturnType(fnDesc: FunctionDescriptor): String = { - val returnT = fnDesc.getReturnType.getConstructor.getDeclarationDescriptor.getDefaultType - val typeParams = fnDesc.getTypeParameters.asScala.toList - - val typesInTypeParams = typeParams.map(_.getDefaultType.getConstructor.getDeclarationDescriptor.getDefaultType) - val hasReturnTypeFromTypeParams = typesInTypeParams.contains(returnT) - if (hasReturnTypeFromTypeParams) { - if (returnT.getConstructor.getSupertypes.asScala.nonEmpty) { - val firstSuperType = returnT.getConstructor.getSupertypes.asScala.toList.head - typeRenderer.render(firstSuperType) - } else { - val renderedReturnT = typeRenderer.render(returnT) - if (renderedReturnT == TypeConstants.tType) TypeConstants.javaLangObject - else renderedReturnT - } - } else { - typeRenderer.render(fnDesc.getReturnType) - } - } - - def fullNameWithSignature(expr: KtSecondaryConstructor, defaultValue: (String, String)): (String, String) = { - val fnDesc = Option(bindingContext.get(BindingContext.CONSTRUCTOR, expr)) - val paramTypeNames = expr.getValueParameters.asScala.map { parameter => - val explicitTypeFullName = - Option(parameter.getTypeReference) - .map(_.getText) - .map(typeRenderer.stripped) - .getOrElse(Defines.UnresolvedNamespace) - // TODO: return all the parameter types in this fn for registration, otherwise they will be missing - parameterType(parameter, explicitTypeFullName) - } - val paramListSignature = s"(${paramTypeNames.mkString(",")})" - val methodName = fnDesc - .map(desc => s"${typeRenderer.renderFqNameForDesc(desc)}${TypeConstants.initPrefix}") - .getOrElse(s"${Defines.UnresolvedNamespace}.${TypeConstants.initPrefix}") - val signature = s"${TypeConstants.void}$paramListSignature" - val fullname = s"$methodName:$signature" - (fullname, signature) - } - - def fullNameWithSignature(expr: KtPrimaryConstructor, defaultValue: (String, String)): (String, String) = { - // if not explicitly defined, the primary ctor will be `null` - if (expr == null) { - defaultValue - } else { - val paramTypeNames = expr.getValueParameters.asScala.map { parameter => - val explicitTypeFullName = Option(parameter.getTypeReference) - .map(_.getText) - .getOrElse(Defines.UnresolvedNamespace) - // TODO: return all the parameter types in this fn for registration, otherwise they will be missing - parameterType(parameter, typeRenderer.stripped(explicitTypeFullName)) - } - val paramListSignature = s"(${paramTypeNames.mkString(",")})" - val methodName = Option(bindingContext.get(BindingContext.CONSTRUCTOR, expr)) - .map { info => s"${typeRenderer.renderFqNameForDesc(info)}${TypeConstants.initPrefix}" } - .getOrElse(s"${Defines.UnresolvedNamespace}.${TypeConstants.initPrefix}") - val signature = s"${TypeConstants.void}$paramListSignature" - val fullname = s"$methodName:$signature" - (fullname, signature) - } - } - - def fullNameWithSignatureAsLambda(expr: KtNamedFunction, lambdaName: String): (String, String) = { - val containingFile = expr.getContainingKtFile - val fileName = containingFile.getName - val packageName = containingFile.getPackageFqName.toString - val astDerivedFullName = s"$packageName:.$lambdaName()" - val astDerivedSignature = anySignature(expr.getValueParameters.asScala.toList) - - val render = for { - mapForEntity <- Option(bindingsForEntity(bindingContext, expr)) - typeInfo <- Option(mapForEntity.get(BindingContext.EXPRESSION_TYPE_INFO.getKey)) - theType = typeInfo.getType - } yield { - val constructorDesc = theType.getConstructor.getDeclarationDescriptor - val constructorType = constructorDesc.getDefaultType - val args = constructorType.getArguments.asScala.drop(1) - - val renderedRetType = - args.lastOption - .map { t => typeRenderer.render(t.getType) } - .getOrElse(TypeConstants.javaLangObject) - val renderedArgs = - if (args.isEmpty) "" - else if (args.size == 1) TypeConstants.javaLangObject - else s"${TypeConstants.javaLangObject}${("," + TypeConstants.javaLangObject) * (args.size - 1)}" - val signature = s"$renderedRetType($renderedArgs)" - val fullName = s"$packageName..$lambdaName:$signature" - (fullName, signature) - } - render.getOrElse((astDerivedFullName, astDerivedSignature)) - } - - def fullNameWithSignature(expr: KtNamedFunction, defaultValue: (String, String)): (String, String) = { - val fnDescMaybe = Option(bindingContext.get(BindingContext.FUNCTION, expr)) - val returnTypeFullName = fnDescMaybe.map(renderedReturnType(_)).getOrElse(Defines.UnresolvedNamespace) - val paramTypeNames = expr.getValueParameters.asScala.map { parameter => - val explicitTypeFullName = - Option(parameter.getTypeReference) - .map(_.getText) - .getOrElse(Defines.UnresolvedNamespace) - // TODO: return all the parameter types in this fn for registration, otherwise they will be missing - parameterType(parameter, typeRenderer.stripped(explicitTypeFullName)) - } - val paramListSignature = s"(${paramTypeNames.mkString(",")})" - - val methodName = for { - fnDesc <- fnDescMaybe - extensionReceiverParam <- Option(fnDesc.getExtensionReceiverParameter) - erpType = extensionReceiverParam.getType - } yield { - if (erpType.isInstanceOf[ErrorType]) { - s"${Defines.UnresolvedNamespace}.${expr.getName}" - } else { - val theType = fnDescMaybe.get.getExtensionReceiverParameter.getType - val renderedType = typeRenderer.render(theType) - s"$renderedType.${expr.getName}" - } - } - - val nameNoParent = s"${methodName.getOrElse(expr.getFqName)}" - val name = if (expr.getContext.isInstanceOf[KtClassBody] || expr.isLocal) { - fnDescMaybe - .map(_.getContainingDeclaration) - .map { containingDecl => - s"${typeRenderer.renderFqNameForDesc(containingDecl.getOriginal)}.${expr.getName}" - } - .getOrElse(nameNoParent) - } else nameNoParent - val signature = s"$returnTypeFullName$paramListSignature" - val fullname = s"$name:$signature" - (fullname, signature) - } - - def isReferenceToClass(expr: KtNameReferenceExpression): Boolean = { - descriptorForNameReference(expr).exists { - case _: LazyJavaClassDescriptor => true - case _: LazyClassDescriptor => true - case _ => false - } - } - - private def descriptorForNameReference(expr: KtNameReferenceExpression): Option[DeclarationDescriptor] = { - Option(bindingsForEntity(bindingContext, expr)) - .map(_ => bindingContext.get(BindingContext.REFERENCE_TARGET, expr)) - } - - def referenceTargetTypeFullName(expr: KtNameReferenceExpression, defaultValue: String): String = { - descriptorForNameReference(expr) - .collect { case desc: PropertyDescriptorImpl => typeRenderer.renderFqNameForDesc(desc.getContainingDeclaration) } - .getOrElse(defaultValue) - } - - def nameReferenceKind(expr: KtNameReferenceExpression): NameReferenceKind = { - descriptorForNameReference(expr) - .collect { - case _: ValueDescriptor => NameReferenceKind.Property - case _: LazyClassDescriptor => NameReferenceKind.ClassName - case _: LazyJavaClassDescriptor => NameReferenceKind.ClassName - case _: DeserializedClassDescriptor => NameReferenceKind.ClassName - case _: EnumEntrySyntheticClassDescriptor => NameReferenceKind.EnumEntry - case unhandled: Any => - logger.debug( - s"Unhandled class in type info fetch in `nameReferenceKind[NameReference]` for `${expr.getText}` with class `${unhandled.getClass}`." - ) - NameReferenceKind.Unknown - } - .getOrElse(NameReferenceKind.Unknown) - } - - def typeFullName(expr: KtPrimaryConstructor | KtSecondaryConstructor, defaultValue: String): String = { - Option(bindingContext.get(BindingContext.CONSTRUCTOR, expr)) - .map { desc => typeRenderer.render(desc.getReturnType) } - .getOrElse(defaultValue) - } - - def typeFullName(expr: KtNameReferenceExpression, defaultValue: String): String = { - descriptorForNameReference(expr) - .flatMap { - case typedDesc: ValueDescriptor => Some(typeRenderer.render(typedDesc.getType)) - // TODO: add test cases for the LazyClassDescriptors (`okio` codebase serves as good example) - case typedDesc: LazyClassDescriptor => Some(typeRenderer.render(typedDesc.getDefaultType)) - case typedDesc: LazyJavaClassDescriptor => Some(typeRenderer.render(typedDesc.getDefaultType)) - case typedDesc: DeserializedClassDescriptor => Some(typeRenderer.render(typedDesc.getDefaultType)) - case typedDesc: EnumEntrySyntheticClassDescriptor => Some(typeRenderer.render(typedDesc.getDefaultType)) - case typedDesc: LazyPackageViewDescriptorImpl => Some(typeRenderer.renderFqNameForDesc(typedDesc)) - case unhandled: Any => - logger.debug(s"Unhandled class type info fetch in for `${expr.getText}` with class `${unhandled.getClass}`.") - None - case null => None - } - .getOrElse(defaultValue) - } - def typeFromImports(name: String, file: KtFile): Option[String] = { - file.getImportList.getImports.asScala.flatMap { directive => - if (directive.getImportedName != null && directive.getImportedName.toString == name.stripSuffix("?")) - Some(directive.getImportPath.getPathStr) - else None - }.headOption - } - - def implicitParameterName(expr: KtLambdaExpression): Option[String] = { - if (!expr.getValueParameters.isEmpty) { - None - } else { - val hasSingleImplicitParameter = - Option(bindingContext.get(BindingContext.EXPECTED_EXPRESSION_TYPE, expr)).exists { desc => - // 1 for the parameter + 1 for the return type == 2 - desc.getConstructor.getParameters.size() == 2 - } - val containingQualifiedExpression = Option(expr.getParent) - .map(_.getParent) - .flatMap(_.getParent match { - case q: KtQualifiedExpression => Some(q) - case _ => None - }) - containingQualifiedExpression match { - case Some(qualifiedExpression) => - resolvedCallDescriptor(qualifiedExpression) match { - case Some(fnDescriptor) => - val originalDesc = fnDescriptor.getOriginal - val vps = originalDesc.getValueParameters - val renderedFqName = typeRenderer.renderFqNameForDesc(originalDesc) - if ( - hasSingleImplicitParameter && - (renderedFqName.startsWith(TypeConstants.kotlinRunPrefix) || - renderedFqName.startsWith(TypeConstants.kotlinApplyPrefix)) - ) { - Some(TypeConstants.scopeFunctionThisParameterName) - // https://kotlinlang.org/docs/lambdas.html#it-implicit-name-of-a-single-parameter - } else if (hasSingleImplicitParameter) { - Some(TypeConstants.lambdaImplicitParameterName) - } else None - case None => None - } - case None => None - } - } - } -} - -object DefaultTypeInfoProvider { - private val logger = LoggerFactory.getLogger(getClass) - - private def bindingsForEntity(bindings: BindingContext, entity: KtElement): KeyFMap = { - try { - val thisField = bindings.getClass.getDeclaredField("this$0") - thisField.setAccessible(true) - val bindingTrace = thisField.get(bindings).asInstanceOf[NoScopeRecordCliBindingTrace] - - val mapField = bindingTrace.getClass.getSuperclass.getSuperclass.getDeclaredField("map") - mapField.setAccessible(true) - val map = mapField.get(bindingTrace) - - val mapMapField = map.getClass.getDeclaredField("map") - mapMapField.setAccessible(true) - val mapMap = mapMapField.get(map).asInstanceOf[java.util.Map[Object, KeyFMap]] - - val mapForEntity = mapMap.get(entity) - mapForEntity - } catch { - case noSuchField: NoSuchFieldException => - logger.debug( - s"Encountered _no such field_ exception while retrieving type info for `${entity.getName}`: `$noSuchField`." - ) - KeyFMap.EMPTY_MAP - case e if NonFatal(e) => - logger.debug(s"Encountered general exception while retrieving type info for `${entity.getName}`: `$e`.") - KeyFMap.EMPTY_MAP - } - } - - private def bindingsForEntityAsString(bindings: BindingContext, entity: KtElement): String = { - val mapForEntity = bindingsForEntity(bindings, entity) - if (mapForEntity != null) { - val keys = mapForEntity.getKeys - entity.toString + ": " + entity.getText + "\n" + - keys.map(key => s"$key: ${mapForEntity.get(key)}").mkString(" ", "\n ", "") - } else { - "No entries" - } - } - - def printBindingsForEntity(bindings: BindingContext, entity: KtElement): Unit = { - println(bindingsForEntityAsString(bindings, entity)) - } -} diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/NameReferenceKind.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/NameReferenceKind.scala deleted file mode 100644 index 8707f1e78537..000000000000 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/NameReferenceKind.scala +++ /dev/null @@ -1,5 +0,0 @@ -package io.joern.kotlin2cpg.types - -enum NameReferenceKind { - case Unknown, ClassName, EnumEntry, LocalVariable, Property -} diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/NameRenderer.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/NameRenderer.scala new file mode 100644 index 000000000000..a78510285bbb --- /dev/null +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/NameRenderer.scala @@ -0,0 +1,209 @@ +package io.joern.kotlin2cpg.types + +import io.joern.kotlin2cpg.types.NameRenderer.{BuiltinTypeTranslationTable, logger} +import io.joern.x2cpg.Defines +import org.jetbrains.kotlin.builtins.jvm.JavaToKotlinClassMap +import org.jetbrains.kotlin.descriptors.impl.TypeAliasConstructorDescriptor +import org.jetbrains.kotlin.descriptors.{ + ClassDescriptor, + ConstructorDescriptor, + DeclarationDescriptor, + FunctionDescriptor, + ModuleDescriptor, + PackageFragmentDescriptor, + TypeParameterDescriptor +} +import org.jetbrains.kotlin.name.FqNameUnsafe +import org.jetbrains.kotlin.types.KotlinType +import org.jetbrains.kotlin.types.error.ErrorClassDescriptor +import org.slf4j.LoggerFactory + +import scala.collection.mutable +import scala.jdk.CollectionConverters.* + +object NameRenderer { + private val logger = LoggerFactory.getLogger(getClass) + + private val BuiltinTypeTranslationTable = mutable.HashMap( + "kotlin.Unit" -> "void", + "kotlin.Boolean" -> "boolean", + "kotlin.Char" -> "char", + "kotlin.Byte" -> "byte", + "kotlin.Short" -> "short", + "kotlin.Int" -> "int", + "kotlin.Float" -> "float", + "kotlin.Long" -> "long", + "kotlin.Double" -> "double", + "kotlin.BooleanArray" -> "boolean[]", + "kotlin.CharArray" -> "char[]", + "kotlin.ByteArray" -> "byte[]", + "kotlin.ShortArray" -> "short[]", + "kotlin.IntArray" -> "int[]", + "kotlin.FloatArray" -> "float[]", + "kotlin.LongArray" -> "long[]", + "kotlin.DoubleArray" -> "double[]" + ) +} + +class NameRenderer { + private val anonDescriptorToIndex = mutable.HashMap.empty[DeclarationDescriptor, Int] + private var anonObjectCounter = 0 + + def descName(desc: DeclarationDescriptor): String = { + if (desc.getName.isSpecial) { + desc match { + case _: ConstructorDescriptor => + Defines.ConstructorMethodName + case functionDesc: FunctionDescriptor => + Defines.ClosurePrefix + getAnonDescIndex(functionDesc) + case _ => + "object$" + getAnonDescIndex(desc) + } + } else { + desc.getName.getIdentifier + } + } + + def descFullName(desc: DeclarationDescriptor): Option[String] = { + val dealiasedDesc = desc match { + case typeAliasDesc: TypeAliasConstructorDescriptor => typeAliasDesc.getUnderlyingConstructorDescriptor + case _ => desc + } + descFullNameInternal(dealiasedDesc).map(_.reverse.mkString("")) + } + + def funcDescSignature(functionDesc: FunctionDescriptor): Option[String] = { + val originalDesc = functionDesc.getOriginal + val extRecvDesc = Option(originalDesc.getExtensionReceiverParameter) + val extRecvTypeFullName = extRecvDesc.flatMap(paramDesc => typeFullName(paramDesc.getType)) + + if (extRecvDesc.nonEmpty && extRecvTypeFullName.isEmpty) { return None } + + val paramTypeFullNames = originalDesc.getValueParameters.asScala.map(paramDesc => typeFullName(paramDesc.getType)) + if (paramTypeFullNames.exists(_.isEmpty)) { return None } + + val returnTypeFullName = if (isConstructorDesc(originalDesc)) { Some("void") } + else { typeFullName(originalDesc.getReturnType) } + + if (returnTypeFullName.isEmpty) { return None } + + val combinedParamTypeFn = paramTypeFullNames.prepended(extRecvTypeFullName) + val signature = s"${returnTypeFullName.get}(${combinedParamTypeFn.flatten.mkString(",")})" + Some(signature) + } + + def combineFunctionFullName(descFullName: String, signature: String): String = { + s"$descFullName:$signature" + } + + def typeFullName(typ: KotlinType): Option[String] = { + val javaFullName = + typ.getConstructor.getDeclarationDescriptor match { + case classDesc: ClassDescriptor => + val kotlinFullName = descFullName(classDesc) + if (kotlinFullName.contains("kotlin.Array")) { + val elementTypeFullName = typeFullName(typ.getArguments.get(0).getType) + elementTypeFullName.map(_ + "[]") + } else { + kotlinFullName.map(typeFullNameKotlinToJava) + } + case typeParamDesc: TypeParameterDescriptor => + val upperBoundTypeFns = typeParamDesc.getUpperBounds.asScala.map(typeFullName) + if (upperBoundTypeFns.exists(_.isEmpty)) { + None + } else { + Some(upperBoundTypeFns.flatten.mkString("&")) + } + case null => + // We do not expect this because to my understanding a typ should always have a constructor + // descriptor. + logger.warn( + s"Found type without constructor descriptor. Typ: $typ Constructor class: ${typ.getConstructor.getClass}" + ) + None + } + + javaFullName + } + + private def typeFullNameKotlinToJava(kotlinFullName: String): String = { + val javaFullName = BuiltinTypeTranslationTable.get(kotlinFullName) + if (javaFullName.isDefined) { + javaFullName.get + } else { + // Nested class fullnames contain '$' in our representation which need to be mapped to '.' + // in order to make use of JavaToKotlinClassMap. + val kotlinFullNameDotOnly = kotlinFullName.replace('$', '.') + val javaFullName = JavaToKotlinClassMap.INSTANCE.mapKotlinToJava(FqNameUnsafe(kotlinFullNameDotOnly)) + + val result = + if (javaFullName != null) { + // In front of nested class sub names we find '.' which needs to be mapped to '$' in our representation. + // After that we can map the normal name separator '/' to '.'. + javaFullName.toString.replace('.', '$').replace('/', '.') + } else { + kotlinFullName + } + result + } + } + + private def getAnonDescIndex(desc: DeclarationDescriptor): Int = { + anonDescriptorToIndex.getOrElseUpdate( + desc, { + val index = anonObjectCounter + anonObjectCounter += 1 + index + } + ) + } + + private def descFullNameInternal(desc: DeclarationDescriptor): Option[List[String]] = { + if (desc.isInstanceOf[ErrorClassDescriptor]) { + return None + } + val parentDesc = desc.getContainingDeclaration + + val parentFnParts = + Option(parentDesc) match { + case None => + Some(Nil) + case Some(parentDesc) => + descFullNameInternal(parentDesc) + } + + if (parentFnParts.isEmpty) { + return None + } + + var extendedFnParts = parentFnParts.get + desc match { + case packageFragmentDesc: PackageFragmentDescriptor => + if (!packageFragmentDesc.getName.isSpecial) { + extendedFnParts = packageFragmentDesc.getFqName.toString :: extendedFnParts + } + case _: ModuleDescriptor => // Do nothing since this is just the root element which has no namespace representation + case _ => + if (extendedFnParts.nonEmpty) { + val separator = if (parentDesc.isInstanceOf[ClassDescriptor] && desc.isInstanceOf[ClassDescriptor]) { + // Nested class + "$" + } else { + "." + } + extendedFnParts = separator :: extendedFnParts + } + val name = descName(desc) + extendedFnParts = name :: extendedFnParts + } + Some(extendedFnParts) + } + + private def isConstructorDesc(functionDesc: FunctionDescriptor): Boolean = { + functionDesc match { + case _: ConstructorDescriptor => true + case _ => false + } + } + +} diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/TypeConstants.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/TypeConstants.scala index f2417979c22a..a01224cec62b 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/TypeConstants.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/TypeConstants.scala @@ -1,22 +1,10 @@ package io.joern.kotlin2cpg.types object TypeConstants { - val any = "ANY" - val classLiteralReplacementMethodName = "getClass" - val initPrefix = io.joern.x2cpg.Defines.ConstructorMethodName - val kotlinFunctionXPrefix = "kotlin.Function" - val kotlinSuspendFunctionXPrefix = "kotlin.coroutines.SuspendFunction" - val kotlinAlsoPrefix = "kotlin.also" - val kotlinApplyPrefix = "kotlin.apply" - val kotlinRunPrefix = "kotlin.run" - val lambdaImplicitParameterName = "it" - val scopeFunctionThisParameterName = "this" - val kotlinUnit = "kotlin.Unit" - val javaLangBoolean = "boolean" - val javaLangClass = "java.lang.Class" - val javaLangObject = "java.lang.Object" - val javaLangString = "java.lang.String" - val kotlin = "kotlin" - val tType = "T" - val void = "void" + val Any = "ANY" + val KotlinFunctionPrefix = "kotlin.Function" + val JavaLangBoolean = "boolean" + val JavaLangObject = "java.lang.Object" + val Kotlin = "kotlin" + val Void = "void" } diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/TypeInfoProvider.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/TypeInfoProvider.scala index 7f1a21abcf7b..52d8c4c1f04a 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/TypeInfoProvider.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/TypeInfoProvider.scala @@ -1,129 +1,246 @@ package io.joern.kotlin2cpg.types -import org.jetbrains.kotlin.descriptors.DescriptorVisibility -import org.jetbrains.kotlin.descriptors.Modality +import kotlin.reflect.jvm.internal.impl.load.java.descriptors.JavaClassConstructorDescriptor +import org.jetbrains.kotlin.cli.jvm.compiler.NoScopeRecordCliBindingTrace +import org.jetbrains.kotlin.com.intellij.util.keyFMap.KeyFMap +import org.jetbrains.kotlin.descriptors.DeclarationDescriptor +import org.jetbrains.kotlin.descriptors.FunctionDescriptor +import org.jetbrains.kotlin.descriptors.impl.ClassConstructorDescriptorImpl +import org.jetbrains.kotlin.descriptors.impl.TypeAliasConstructorDescriptorImpl +import org.jetbrains.kotlin.descriptors.CallableDescriptor +import org.jetbrains.kotlin.descriptors.PropertyDescriptor +import org.jetbrains.kotlin.load.java.`lazy`.descriptors.LazyJavaClassDescriptor +import org.jetbrains.kotlin.load.java.sources.JavaSourceElement +import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaMethod import org.jetbrains.kotlin.psi.{ - KtAnnotationEntry, + Call, + KtArrayAccessExpression, KtBinaryExpression, KtCallExpression, - KtClassLiteralExpression, - KtClassOrObject, - KtDestructuringDeclarationEntry, KtElement, KtExpression, - KtFile, - KtLambdaExpression, - KtNamedFunction, KtNameReferenceExpression, - KtParameter, - KtPrimaryConstructor, - KtProperty, KtQualifiedExpression, - KtSecondaryConstructor, - KtTypeAlias, - KtTypeReference + KtSuperExpression, + KtThisExpression +} +import org.jetbrains.kotlin.resolve.BindingContext +import org.jetbrains.kotlin.resolve.DescriptorUtils +import org.jetbrains.kotlin.resolve.`lazy`.descriptors.LazyClassDescriptor +import org.jetbrains.kotlin.util.slicedMap.ReadOnlySlice +import org.slf4j.LoggerFactory + +import scala.annotation.unused +import scala.util.control.NonFatal + +class TypeInfoProvider(val bindingContext: BindingContext) { + + import io.joern.kotlin2cpg.types.TypeInfoProvider.bindingsForEntity + + def usedAsExpression(expr: KtExpression): Option[Boolean] = { + val mapForEntity = bindingsForEntity(bindingContext, expr) + Option(mapForEntity.get(BindingContext.USED_AS_EXPRESSION.getKey)).map(_.booleanValue()) + } + + def usedAsImplicitThis(expr: KtNameReferenceExpression): Boolean = { + val mapForEntity = bindingsForEntity(bindingContext, expr) + val isCallExprWithTarget = Option(mapForEntity) + .map(_.getKeys) + .exists(ks => + ks.contains(BindingContext.CALL.getKey) + && ks.contains(BindingContext.USED_AS_EXPRESSION.getKey) + && ks.contains(BindingContext.REFERENCE_TARGET.getKey) + ) + isCallExprWithTarget && resolvedPropertyDescriptor(expr).exists { d => + d.getDispatchReceiverParameter != null && d.getDispatchReceiverParameter.getName.asString() == "" + } + } + + def isStaticMethodCall(expr: KtQualifiedExpression): Boolean = { + resolvedCallDescriptor(expr) + .map(_.getSource) + .collect { case s: JavaSourceElement => s } + .map(_.getJavaElement) + .collect { case bjm: BinaryJavaMethod => bjm } + .exists(_.isStatic) + } + + def isRefToCompanionObject(expr: KtNameReferenceExpression): Boolean = { + val mapForEntity = bindingsForEntity(bindingContext, expr) + Option(mapForEntity) + .map(_.getKeys) + .exists(_.contains(BindingContext.SHORT_REFERENCE_TO_COMPANION_OBJECT.getKey)) + } + + private def subexpressionForResolvedCallInfo(expr: KtExpression): KtExpression = { + expr match { + case typedExpr: KtCallExpression => + Option(typedExpr.getFirstChild) + .collect { case expr: KtExpression => expr } + .getOrElse(expr) + case typedExpr: KtQualifiedExpression => + Option(typedExpr.getSelectorExpression) + .collect { case expr: KtCallExpression => expr } + .map(subexpressionForResolvedCallInfo) + .getOrElse(typedExpr) + case typedExpr: KtBinaryExpression => + Option(typedExpr.getChildren.toList(1)) + .collect { case expr: KtExpression => expr } + .getOrElse(expr) + case _ => expr + } + } + + private def resolvedPropertyDescriptor(expr: KtNameReferenceExpression): Option[PropertyDescriptor] = { + val descMaybe = for { + callForSubexpression <- Option(bindingContext.get(BindingContext.REFERENCE_TARGET, expr)) + desc = callForSubexpression + } yield desc + descMaybe.collect { case desc: PropertyDescriptor => desc } + } + + private def resolvedCallDescriptor(expr: KtExpression): Option[FunctionDescriptor] = { + val relevantSubexpression = subexpressionForResolvedCallInfo(expr) + val descMaybe = for { + callForSubexpression <- Option(bindingContext.get(BindingContext.CALL, relevantSubexpression)) + resolvedCallForSubexpression <- Option(bindingContext.get(BindingContext.RESOLVED_CALL, callForSubexpression)) + desc = resolvedCallForSubexpression.getResultingDescriptor + } yield desc + descMaybe.collect { case desc: FunctionDescriptor => desc } + } + + private def isConstructorDescriptor(desc: FunctionDescriptor): Boolean = { + desc match { + case _: JavaClassConstructorDescriptor => true + case _: ClassConstructorDescriptorImpl => true + case _: TypeAliasConstructorDescriptorImpl => true + case _ => false + } + } + + def isConstructorCall(expr: KtExpression): Option[Boolean] = { + expr match { + case _: KtCallExpression | _: KtQualifiedExpression => + resolvedCallDescriptor(expr) match { + case Some(desc) if isConstructorDescriptor(desc) => Some(true) + case _ => Some(false) + } + case _ => Some(false) + } + } + + private def hasStaticDesc(expr: KtQualifiedExpression): Boolean = { + resolvedCallDescriptor(expr).forall(_.getDispatchReceiverParameter == null) + } + + def bindingKind(expr: KtQualifiedExpression): CallKind = { + val isStaticBasedOnStructure = expr.getReceiverExpression.isInstanceOf[KtSuperExpression] + if (isStaticBasedOnStructure) return CallKind.StaticCall + + val isDynamicBasedOnStructure = expr.getReceiverExpression match { + case _: KtArrayAccessExpression => true + case _: KtThisExpression => true + case _ => false + } + if (isDynamicBasedOnStructure) return CallKind.DynamicCall + + resolvedCallDescriptor(expr) + .map { desc => + val isExtension = DescriptorUtils.isExtension(desc) + val isStatic = DescriptorUtils.isStaticDeclaration(desc) || hasStaticDesc(expr) + + if (isExtension) CallKind.ExtensionCall + else if (isStatic) CallKind.StaticCall + else CallKind.DynamicCall + } + .getOrElse(CallKind.Unknown) + } + + def isReferenceToClass(expr: KtNameReferenceExpression): Boolean = { + descriptorForNameReference(expr).exists { + case _: LazyJavaClassDescriptor => true + case _: LazyClassDescriptor => true + case _ => false + } + } + + private def descriptorForNameReference(expr: KtNameReferenceExpression): Option[DeclarationDescriptor] = { + Option(bindingsForEntity(bindingContext, expr)) + .map(_ => bindingContext.get(BindingContext.REFERENCE_TARGET, expr)) + } } -case class AnonymousObjectContext(declaration: KtElement) - -trait TypeInfoProvider(val typeRenderer: TypeRenderer = new TypeRenderer()) { - def isExtensionFn(fn: KtNamedFunction): Boolean - - def usedAsExpression(expr: KtExpression): Option[Boolean] - - def containingTypeDeclFullName(ktFn: KtNamedFunction, defaultValue: String): String - - def isStaticMethodCall(expr: KtQualifiedExpression): Boolean - - def visibility(fn: KtNamedFunction): Option[DescriptorVisibility] - - def modality(fn: KtNamedFunction): Option[Modality] - - def modality(ktClass: KtClassOrObject): Option[Modality] - - def returnType(elem: KtNamedFunction, defaultValue: String): String - - def containingDeclFullName(expr: KtCallExpression): Option[String] - - def containingDeclType(expr: KtQualifiedExpression, defaultValue: String): String - - def expressionType(expr: KtExpression, defaultValue: String): String - - def inheritanceTypes(expr: KtClassOrObject, or: Seq[String]): Seq[String] - - def parameterType(expr: KtParameter, defaultValue: String): String - - def destructuringEntryType(expr: KtDestructuringDeclarationEntry, defaultValue: String): String - - def propertyType(expr: KtProperty, defaultValue: String): String - - def fullName(expr: KtClassOrObject, defaultValue: String, ctx: Option[AnonymousObjectContext] = None): String - - def fullName(expr: KtTypeAlias, defaultValue: String): String - - def fullNameWithSignature(expr: KtDestructuringDeclarationEntry, defaultValue: (String, String)): (String, String) - - def aliasTypeFullName(expr: KtTypeAlias, defaultValue: String): String - - def typeFullName(expr: KtNameReferenceExpression, defaultValue: String): String - - def referenceTargetTypeFullName(expr: KtNameReferenceExpression, defaultValue: String): String - - def typeFullName(expr: KtBinaryExpression, defaultValue: String): String - - def typeFullName(expr: KtAnnotationEntry, defaultValue: String): String - - def isReferenceToClass(expr: KtNameReferenceExpression): Boolean - - def bindingKind(expr: KtQualifiedExpression): CallKind - - def fullNameWithSignature(expr: KtQualifiedExpression, or: (String, String)): (String, String) - - def fullNameWithSignature(call: KtCallExpression, or: (String, String)): (String, String) - - def fullNameWithSignature(expr: KtPrimaryConstructor, or: (String, String)): (String, String) - - def fullNameWithSignature(expr: KtSecondaryConstructor, or: (String, String)): (String, String) - - def fullNameWithSignature(call: KtBinaryExpression, or: (String, String)): (String, String) - - def fullNameWithSignature(expr: KtNamedFunction, or: (String, String)): (String, String) - - def fullNameWithSignatureAsLambda(expr: KtNamedFunction, lambdaName: String): (String, String) - - def fullNameWithSignature(expr: KtClassLiteralExpression, or: (String, String)): (String, String) - - def fullNameWithSignature(expr: KtLambdaExpression, lambdaName: String): (String, String) - - def anySignature(args: Seq[Any]): String - - def returnTypeFullName(expr: KtLambdaExpression): String - - def hasApplyOrAlsoScopeFunctionParent(expr: KtLambdaExpression): Boolean - - def nameReferenceKind(expr: KtNameReferenceExpression): NameReferenceKind - - def isConstructorCall(expr: KtExpression): Option[Boolean] - - def typeFullName(expr: KtTypeReference, defaultValue: String): String - - def typeFullName(expr: KtPrimaryConstructor | KtSecondaryConstructor, defaultValue: String): String - - def typeFullName(expr: KtCallExpression, defaultValue: String): String - - def typeFullName(expr: KtParameter, defaultValue: String): String - - def typeFullName(expr: KtDestructuringDeclarationEntry, defaultValue: String): String - - def hasStaticDesc(expr: KtQualifiedExpression): Boolean - - def implicitParameterName(expr: KtLambdaExpression): Option[String] - - def isCompanionObject(expr: KtClassOrObject): Boolean - - def isRefToCompanionObject(expr: KtNameReferenceExpression): Boolean - - def typeFullName(expr: KtClassOrObject, defaultValue: String): String - - def typeFromImports(name: String, file: KtFile): Option[String] +object TypeInfoProvider { + private val logger = LoggerFactory.getLogger(getClass) + + /** For internal debugging purposes */ + @unused + def allBindingsOfKind[K, V](bindings: BindingContext, kind: ReadOnlySlice[K, V]): collection.Seq[(K, V)] = { + val thisField = bindings.getClass.getDeclaredField("this$0") + thisField.setAccessible(true) + val bindingTrace = thisField.get(bindings).asInstanceOf[NoScopeRecordCliBindingTrace] + + val mapField = bindingTrace.getClass.getSuperclass.getSuperclass.getDeclaredField("map") + mapField.setAccessible(true) + val map = mapField.get(bindingTrace) + + val mapMapField = map.getClass.getDeclaredField("map") + mapMapField.setAccessible(true) + val mapMap = mapMapField.get(map).asInstanceOf[java.util.Map[Object, KeyFMap]] + + val result = scala.collection.mutable.ArrayBuffer.empty[(K, V)] + + mapMap.forEach { (keyObject: Object, fMap: KeyFMap) => + val kindValue = fMap.get(kind.getKey) + if (kindValue != null) { + result.append((keyObject.asInstanceOf[K], kindValue)) + } + } + + result + } + + /** For internal debugging purposes */ + def bindingsForEntity(bindings: BindingContext, entity: KtElement | Call): KeyFMap = { + try { + val thisField = bindings.getClass.getDeclaredField("this$0") + thisField.setAccessible(true) + val bindingTrace = thisField.get(bindings).asInstanceOf[NoScopeRecordCliBindingTrace] + + val mapField = bindingTrace.getClass.getSuperclass.getSuperclass.getDeclaredField("map") + mapField.setAccessible(true) + val map = mapField.get(bindingTrace) + + val mapMapField = map.getClass.getDeclaredField("map") + mapMapField.setAccessible(true) + val mapMap = mapMapField.get(map).asInstanceOf[java.util.Map[Object, KeyFMap]] + + val mapForEntity = mapMap.get(entity) + mapForEntity + } catch { + case noSuchField: NoSuchFieldException => + logger.debug(s"Encountered _no such field_ exception while retrieving type info for `$entity`: `$noSuchField`.") + KeyFMap.EMPTY_MAP + case e if NonFatal(e) => + logger.debug(s"Encountered general exception while retrieving type info for `$entity`: `$e`.") + KeyFMap.EMPTY_MAP + } + } + + /** For internal debugging purposes */ + def bindingsForEntityAsString(bindings: BindingContext, entity: KtElement): String = { + val mapForEntity = bindingsForEntity(bindings, entity) + if (mapForEntity != null) { + val keys = mapForEntity.getKeys + entity.toString + ": " + entity.getText + "\n" + + keys.map(key => s"$key: ${mapForEntity.get(key)}").mkString(" ", "\n ", "") + } else { + "No entries" + } + } + + @unused + def printBindingsForEntity(bindings: BindingContext, entity: KtElement): Unit = { + println(bindingsForEntityAsString(bindings, entity)) + } } diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/TypeRenderer.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/TypeRenderer.scala deleted file mode 100644 index 1660485f84e8..000000000000 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/types/TypeRenderer.scala +++ /dev/null @@ -1,200 +0,0 @@ -package io.joern.kotlin2cpg.types - -import io.joern.kotlin2cpg.psi.PsiUtils -import io.joern.x2cpg.Defines -import org.jetbrains.kotlin.descriptors.{ClassDescriptor, DeclarationDescriptor, SimpleFunctionDescriptor} -import org.jetbrains.kotlin.resolve.{DescriptorToSourceUtils, DescriptorUtils} -import org.jetbrains.kotlin.types.{ErrorUtils, KotlinType, TypeProjection, TypeUtils} -import org.jetbrains.kotlin.types.error.ErrorType -import org.jetbrains.kotlin.builtins.jvm.JavaToKotlinClassMap -import org.jetbrains.kotlin.name.FqName -import org.jetbrains.kotlin.renderer.{DescriptorRenderer, DescriptorRendererImpl, DescriptorRendererOptionsImpl} -import org.jetbrains.kotlin.types.typeUtil.TypeUtilsKt -import org.jetbrains.kotlin.resolve.jvm.JvmPrimitiveType - -import scala.jdk.CollectionConverters.* - -object TypeRenderer { - - private val cpgUnresolvedType = - ErrorUtils.createUnresolvedType(Defines.UnresolvedNamespace, new java.util.ArrayList[TypeProjection]()) - - val primitiveArrayMappings: Map[String, String] = Map[String, String]( - "kotlin.BooleanArray" -> "boolean[]", - "kotlin.ByteArray" -> "byte[]", - "kotlin.CharArray" -> "char[]", - "kotlin.DoubleArray" -> "double[]", - "kotlin.FloatArray" -> "float[]", - "kotlin.IntArray" -> "int[]", - "kotlin.LongArray" -> "long[]", - "kotlin.ShortArray" -> "short[]" - ) - -} - -class TypeRenderer(val keepTypeArguments: Boolean = false) { - - import TypeRenderer.* - - private def descriptorRenderer(): DescriptorRenderer = { - val opts = new DescriptorRendererOptionsImpl - opts.setParameterNamesInFunctionalTypes(false) - opts.setInformativeErrorType(false) - opts.setTypeNormalizer { - case _: ErrorType => cpgUnresolvedType - case t => t - } - new DescriptorRendererImpl(opts) - } - - def renderFqNameForDesc(desc: DeclarationDescriptor): String = { - val renderer = descriptorRenderer() - val fqName = DescriptorUtils.getFqName(desc) - val simpleRender = stripped(renderer.renderFqName(fqName)) - def maybeReplacedOrTake(c: DeclarationDescriptor, or: String): String = { - c match { - case tc: ClassDescriptor if DescriptorUtils.isCompanionObject(tc) || tc.isInner => - val rendered = stripped(renderer.renderFqName(fqName)) - rendered.replaceFirst("\\." + c.getName, "\\$" + c.getName) - case tc: ClassDescriptor if DescriptorUtils.isAnonymousObject(tc) => - val rendered = stripped(renderer.renderFqName(fqName)) - - val psiElement = DescriptorToSourceUtils.getSourceFromDescriptor(tc) - val psiContainingDecl = DescriptorToSourceUtils.getSourceFromDescriptor(tc.getContainingDeclaration) - val objectIdx = - PsiUtils - .objectIdxMaybe(psiElement, psiContainingDecl) - .getOrElse("nan") - val out = rendered.replaceFirst("\\.$", "\\$object\\$" + s"$objectIdx") - out - case _ => or - } - } - val strippedOfContainingDeclarationIfNeeded = - Option(desc.getContainingDeclaration) - .map { - case c: ClassDescriptor => maybeReplacedOrTake(c, simpleRender) - case _ => simpleRender - } - .getOrElse(simpleRender) - desc match { - case c: ClassDescriptor => maybeReplacedOrTake(c, strippedOfContainingDeclarationIfNeeded) - case _ => strippedOfContainingDeclarationIfNeeded - } - } - - private def maybeUnwrappedRender(render: String, unwrapPrimitives: Boolean, fqName: FqName) = { - val isWrapperOfPrimitiveType = JvmPrimitiveType.isWrapperClassName(fqName) - if (unwrapPrimitives && isWrapperOfPrimitiveType) { - JvmPrimitiveType - .values() - .toList - .filter(_.getWrapperFqName.toString == fqName.toString) - .map(_.getJavaKeywordName) - .head - } else render - } - - private def renderForDescriptor(descriptor: ClassDescriptor, unwrapPrimitives: Boolean, t: KotlinType): String = { - val renderer = descriptorRenderer() - val fqName = DescriptorUtils.getFqName(descriptor) - Option(JavaToKotlinClassMap.INSTANCE.mapKotlinToJava(fqName)) - .map { mappedType => - val fqName = mappedType.asSingleFqName() - val render = stripped(renderer.renderFqName(fqName.toUnsafe)) - maybeUnwrappedRender(render, unwrapPrimitives, fqName) - } - .getOrElse { - if (DescriptorUtils.isCompanionObject(descriptor) || descriptor.isInner) { - val rendered = stripped(renderer.renderFqName(fqName)) - val companionObjectName = descriptor.getName - // replaces `apkg.ContainingClass.CompanionObjectName` with `apkg.ContainingClass$CompanionObjectName` - rendered.replaceFirst("\\." + companionObjectName, "\\$" + companionObjectName) - } else { - descriptor.getContainingDeclaration match { - case fn: SimpleFunctionDescriptor => - val renderedFqName = stripped(renderer.renderFqName(DescriptorUtils.getFqName(descriptor))) - val containingDescName = fn.getName - // replaces `apkg.containingMethodName.className` with `apkg.className$containingMethodName` - renderedFqName.replaceFirst("\\." + containingDescName + "\\.([^.]+)", ".$1" + "\\$" + containingDescName) - case _ => stripped(renderer.renderType(t)) - } - } - } - } - - def render(t: KotlinType, shouldMapPrimitiveArrayTypes: Boolean = true, unwrapPrimitives: Boolean = true): String = { - val rendered = - if (t.isInstanceOf[ErrorType]) TypeConstants.any - else if (TypeUtilsKt.isTypeParameter(t)) TypeConstants.javaLangObject - else if (isFunctionXType(t)) TypeConstants.kotlinFunctionXPrefix + (t.getArguments.size() - 1).toString - else - Option(TypeUtils.getClassDescriptor(t)) - .map { descriptor => - renderForDescriptor(descriptor, unwrapPrimitives, t) - } - .getOrElse { - val renderer = descriptorRenderer() - val relevantT = Option(TypeUtilsKt.getImmediateSuperclassNotAny(t)).getOrElse(t) - stripped(renderer.renderType(relevantT)) - } - val renderedType = - if (shouldMapPrimitiveArrayTypes && primitiveArrayMappings.contains(rendered)) primitiveArrayMappings(rendered) - else if (rendered == TypeConstants.kotlinUnit) TypeConstants.void - else rendered - - if (keepTypeArguments && !t.getArguments.isEmpty) { - val typeArgs = t.getArguments.asScala - .map(_.getType) - .map(render(_, shouldMapPrimitiveArrayTypes, unwrapPrimitives)) - .mkString(",") - s"$renderedType<$typeArgs>" - } else { - renderedType - } - } - - private def isFunctionXType(t: KotlinType): Boolean = { - val renderer = descriptorRenderer() - val renderedConstructor = renderer.renderTypeConstructor(t.getConstructor) - renderedConstructor.startsWith(TypeConstants.kotlinFunctionXPrefix) || - renderedConstructor.startsWith(TypeConstants.kotlinSuspendFunctionXPrefix) - } - - def stripped(typeName: String): String = { - def stripTypeParams(typeName: String): String = { - // (? when it is right at the beginning. - // We do this because at the beginning of a type name we cannot - // have type parameters but instead which - // we do not want to strip. - // Sometimes with a lambda expression as an argument, we see a - // named function instead of a lambda, so we keep the - // tag. - typeName.replaceAll("(?", "") - } - def stripOut(name: String): String = { - if (name.contains("<") && name.contains(">") && name.contains("out")) { - name.replaceAll("(<[^o]*)[(]?out[)]?[ ]*([a-zA-Z])", "<$2") - } else { - name - } - } - def stripOptionality(typeName: String): String = { - typeName.replaceAll("!", "").replaceAll("\\?", "") - } - def stripDebugInfo(typeName: String): String = { - if (typeName.contains("/* =")) { - typeName.split("/\\* =")(0) - } else { - typeName - } - } - - val t1 = stripOut(typeName) - val t2 = stripDebugInfo(t1) - val t3 = stripOptionality(t2) - val t4 = stripTypeParams(t3) - t4.trim().replaceAll(" ", "").replaceAll("`", "") - } -} diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/compiler/CompilerAPITests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/compiler/CompilerAPITests.scala index d7c78defaa52..e8dd77a86bb6 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/compiler/CompilerAPITests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/compiler/CompilerAPITests.scala @@ -77,8 +77,9 @@ class CompilerAPITests extends AnyFreeSpec with Matchers { "should not contain methods with unresolved types/namespaces" in { val command = - if (scala.util.Properties.isWin) "cmd.exe /C gradlew.bat gatherDependencies" else "./gradlew gatherDependencies" - ExternalCommand.run(command, projectDirPath) shouldBe Symbol("success") + if (scala.util.Properties.isWin) Seq("cmd.exe", "/C", "gradlew.bat", "gatherDependencies") + else Seq("./gradlew", "gatherDependencies") + ExternalCommand.run(command, projectDirPath).toTry shouldBe Symbol("success") val config = Config(classpath = Set(projectDependenciesPath.toString)) val cpg = new Kotlin2Cpg().createCpg(projectDirPath)(config).getOrElse { fail("Could not create a CPG!") diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/compiler/JavaInteroperabilityTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/compiler/JavaInteroperabilityTests.scala index 580b2056da24..55f047956f61 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/compiler/JavaInteroperabilityTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/compiler/JavaInteroperabilityTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.compiler import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class JavaInteroperabilityTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with Java interop" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/config/ConfigTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/config/ConfigTests.scala index 7aa57c3fa9e4..df70c46988e3 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/config/ConfigTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/config/ConfigTests.scala @@ -18,7 +18,9 @@ class ConfigTests extends AnyWordSpec with Matchers with Inside { "--output", "OUTPUT", "--exclude", - "1EXCLUDE_FILE,2EXCLUDE_FILE", + "1EXCLUDE_FILE", + "--exclude", + "2EXCLUDE_FILE", "--exclude-regex", "EXCLUDE_REGEX", // Frontend-specific args diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/CollectionsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/CollectionsTests.scala index c9bb114aba6a..baaa894378a8 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/CollectionsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/CollectionsTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CollectionsTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ControlExpressionsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ControlExpressionsTests.scala index 83742fc2916a..99031a2c5510 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ControlExpressionsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ControlExpressionsTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ControlExpressionsTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/DestructuringTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/DestructuringTests.scala index 23d0a75e2e71..108ae3e4fab1 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/DestructuringTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/DestructuringTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DestructuringTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ExtensionFnsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ExtensionFnsTests.scala index ef42c65d7aac..d6df51061783 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ExtensionFnsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ExtensionFnsTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ExtensionFnsTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ForTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ForTests.scala index 5b52128e351f..a810b2f050f9 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ForTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ForTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ForTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/FunctionCallTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/FunctionCallTests.scala index bffc61e47053..f5e138aec00c 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/FunctionCallTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/FunctionCallTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class FunctionCallTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/GenericsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/GenericsTests.scala index 6510a41d156f..7c4528652d61 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/GenericsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/GenericsTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class GenericsTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/IfTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/IfTests.scala index 39b9a9a03f59..3c2b48e764b2 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/IfTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/IfTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class IfTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/InterproceduralTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/InterproceduralTests.scala index 6d0d9ea0f607..474fd0b25863 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/InterproceduralTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/InterproceduralTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class InterproceduralTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/JavaInteroperabilityTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/JavaInteroperabilityTests.scala index 3219e96572f6..b6e998601691 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/JavaInteroperabilityTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/JavaInteroperabilityTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class JavaInteroperabilityTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/LambdaTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/LambdaTests.scala index 3946446c2657..7151500a7d9b 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/LambdaTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/LambdaTests.scala @@ -2,13 +2,14 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve "CPG for code containing a lambda with parameter destructuring" should { - val cpg = code("""|package mypkg + val cpg = code(""" + |package mypkg | |fun f1(p: String) { | val m = mapOf(p to 1, "two" to 2, "three" to 3) @@ -23,13 +24,11 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = true) { flows.map(flowToResultPairs).toSet shouldBe Set( List( - ("f1(p)", Some(3)), - ("p to 1", Some(4)), - ("mapOf(p to 1, \"two\" to 2, \"three\" to 3)", Some(4)), - ("val m = mapOf(p to 1, \"two\" to 2, \"three\" to 3)", Some(4)), - ("m.forEach { (k, v) -> println(k) }", Some(5)), - ("0(k, v)", Some(5)), - ("println(k)", Some(5)) + ("f1(p)", Some(4)), + ("tmp_1 = it", None), + ("tmp_1.component1()", Some(6)), + ("k = tmp_1.component1()", Some(6)), + ("println(k)", Some(6)) ) ) } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ObjectExpressionsAndDeclarationsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ObjectExpressionsAndDeclarationsTests.scala index 7ab56c2b3df4..f1a7e87f961d 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ObjectExpressionsAndDeclarationsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ObjectExpressionsAndDeclarationsTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ObjectExpressionsAndDeclarationsTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/OperatorTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/OperatorTests.scala index 927c9dcf8069..8e095ce30420 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/OperatorTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/OperatorTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class OperatorTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ScopeFunctionsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ScopeFunctionsTests.scala index 7e70895c5772..690d7eaec9bd 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ScopeFunctionsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/ScopeFunctionsTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ScopeFunctionsTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/SimpleDataFlowTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/SimpleDataFlowTests.scala index ac54ea19e525..987bd4bd9c10 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/SimpleDataFlowTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/SimpleDataFlowTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class SimpleDataFlowTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/TryTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/TryTests.scala index 7002f4da3b0e..d25e5e83c117 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/TryTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/TryTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class TryTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/WhenTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/WhenTests.scala index 0c94cf520002..021fc3198813 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/WhenTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/WhenTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class WhenTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/WhileTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/WhileTests.scala index 74beb35f10fd..aa7e419eabda 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/WhileTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/dataflow/WhileTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.dataflow import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class WhileTests extends KotlinCode2CpgFixture(withOssDataflow = true) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/io/Kotlin2CpgHTTPServerTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/io/Kotlin2CpgHTTPServerTests.scala new file mode 100644 index 000000000000..2ceaa95affa3 --- /dev/null +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/io/Kotlin2CpgHTTPServerTests.scala @@ -0,0 +1,84 @@ +package io.joern.kotlin2cpg.io + +import better.files.File +import io.joern.x2cpg.utils.server.FrontendHTTPClient +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader +import io.shiftleft.semanticcpg.language.* +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.collection.parallel.CollectionConverters.RangeIsParallelizable +import scala.util.Failure +import scala.util.Success + +class Kotlin2CpgHTTPServerTests extends AnyWordSpec with Matchers with BeforeAndAfterAll { + + private var port: Int = -1 + + private def newProjectUnderTest(index: Option[Int] = None): File = { + val dir = File.newTemporaryDirectory("kotlin2cpgTestsHttpTest") + val file = dir / "main.kt" + file.createIfNotExists(createParents = true) + val indexStr = index.map(_.toString).getOrElse("") + file.writeText(s""" + |package mypkg + |fun main(args : Array) { + | println($indexStr) + |} + |""".stripMargin) + file.deleteOnExit() + dir.deleteOnExit() + } + + override def beforeAll(): Unit = { + // Start server + port = io.joern.kotlin2cpg.Main.startup() + } + + override def afterAll(): Unit = { + // Stop server + io.joern.kotlin2cpg.Main.stop() + } + + "Using kotlin2cpg in server mode" should { + "build CPGs correctly (single test)" in { + val cpgOutFile = File.newTemporaryFile("kotlin2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l shouldBe List("println()") + } + } + + "build CPGs correctly (multi-threaded test)" in { + (0 until 10).par.foreach { index => + val cpgOutFile = File.newTemporaryFile("kotlin2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest(Some(index)) + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l shouldBe List(s"println($index)") + } + } + } + } + +} diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/SourceFilesPickerTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/io/SourceFilesPickerTests.scala similarity index 94% rename from joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/SourceFilesPickerTests.scala rename to joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/io/SourceFilesPickerTests.scala index 730b34aa9ee4..fd252128c740 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/SourceFilesPickerTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/io/SourceFilesPickerTests.scala @@ -1,7 +1,6 @@ -package io.joern.kotlin2cpg +package io.joern.kotlin2cpg.io import io.joern.kotlin2cpg.files.SourceFilesPicker - import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers import org.scalatest.BeforeAndAfterAll diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/postProcessing/TypeRecoveryPassTest.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/postProcessing/TypeRecoveryPassTest.scala index 6e241e7e2afd..bfbc183ffe73 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/postProcessing/TypeRecoveryPassTest.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/postProcessing/TypeRecoveryPassTest.scala @@ -1,8 +1,9 @@ package io.joern.kotlin2cpg.postProcessing import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture +import io.joern.x2cpg.Defines import io.shiftleft.semanticcpg.language.{ICallResolver, NoResolve} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class TypeRecoveryPassTest extends KotlinCode2CpgFixture(withPostProcessing = true) { @@ -27,7 +28,8 @@ class TypeRecoveryPassTest extends KotlinCode2CpgFixture(withPostProcessing = tr } "be able to faciliate methodFullName resolution for call made from identifier object" in { - cpg.call("getInstance").methodFullName.l shouldBe List("com.firebase.ui.auth.AuthUI.getInstance:ANY()") + cpg.call("getInstance").methodFullName.l shouldBe + List(s"com.firebase.ui.auth.AuthUI.getInstance:${Defines.UnresolvedSignature}(0)") } "be able to faciliate methodFullName resolution for call chaining" ignore { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnnotationsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnnotationsTests.scala index 71d1ee17cc51..744e0aa100f5 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnnotationsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnnotationsTests.scala @@ -1,8 +1,9 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture +import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.nodes.{Annotation, AnnotationLiteral} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class AnnotationsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with two identical calls, one annotated and one not" should { @@ -606,7 +607,7 @@ class AnnotationsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { } } - "CPG for code with a custom annotation" should { + "CPG for code with a custom annotation and fallback handling" should { val cpg = code(""" |package mypkg |import retrofit2.http.POST @@ -618,12 +619,16 @@ class AnnotationsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { |} |""".stripMargin) - "contain an ANNOTATION node" in { - cpg.all.collectAll[Annotation].codeExact("@POST(\"/name\")").size shouldBe 1 - } - "the ANNOTATION node should have correct full name" in { - cpg.all.collectAll[Annotation].codeExact("@POST(\"/name\")").fullName.head shouldBe "retrofit2.http.POST" + inside(cpg.annotation.codeExact("@POST(\"/name\")").l) { case List(annotation) => + annotation.fullName shouldBe "retrofit2.http.POST" + } + inside(cpg.annotation.code(".*Headers.*").l) { case List(annotation) => + annotation.fullName shouldBe s"${Defines.UnresolvedNamespace}.Headers" + } + inside(cpg.annotation.code(".*Body.*").l) { case List(annotation) => + annotation.fullName shouldBe s"${Defines.UnresolvedNamespace}.Body" + } } } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnonymousFunctionsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnonymousFunctionsTests.scala index fad29eeaf8b4..a370fe9a8f9d 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnonymousFunctionsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnonymousFunctionsTests.scala @@ -1,8 +1,9 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture +import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, ModifierTypes} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.Ignore import scala.annotation.unused @@ -23,8 +24,8 @@ class AnonymousFunctionsTests extends KotlinCode2CpgFixture(withOssDataflow = fa "should contain a METHOD node for the anonymous fn with the correct props set" in { val List(m) = cpg.method.fullName(".*lambda.*0.*").l - m.fullName shouldBe "mypkg..0:java.lang.Object(java.lang.Object)" - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.fullName shouldBe s"mypkg.foo.${Defines.ClosurePrefix}0:boolean(int)" + m.signature shouldBe "boolean(int)" } "should contain a METHOD node for the lambda with a corresponding METHOD_RETURN which has the correct props set" in { @@ -58,8 +59,8 @@ class AnonymousFunctionsTests extends KotlinCode2CpgFixture(withOssDataflow = fa "should contain a METHOD node for the anonymous fn with the correct props set" in { val List(m) = cpg.method.fullName(".*lambda.*0.*").l - m.fullName shouldBe "mypkg..0:java.lang.Object(java.lang.Object)" - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.fullName shouldBe s"mypkg.foo.${Defines.ClosurePrefix}0:boolean(int)" + m.signature shouldBe "boolean(int)" } "should contain a METHOD node for the lambda with a corresponding METHOD_RETURN which has the correct props set" in { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ArithmeticOperationsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ArithmeticOperationsTests.scala index 1116622eaeaa..80f1d9695263 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ArithmeticOperationsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ArithmeticOperationsTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ArithmeticOperationsTests extends KotlinCode2CpgFixture(withOssDataflow = true) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ArrayAccessExprsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ArrayAccessExprsTests.scala index 5a9b5767607c..84632d3f4acf 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ArrayAccessExprsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ArrayAccessExprsTests.scala @@ -3,7 +3,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ArrayAccessExprsTests extends KotlinCode2CpgFixture(withOssDataflow = true) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ArrayTypeNameTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ArrayTypeNameTests.scala new file mode 100644 index 000000000000..1aef73852a49 --- /dev/null +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ArrayTypeNameTests.scala @@ -0,0 +1,56 @@ +package io.joern.kotlin2cpg.querying + +import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture +import io.shiftleft.semanticcpg.language.* + +class ArrayTypeNameTests extends KotlinCode2CpgFixture(withOssDataflow = true) { + "test array type full name in method signature" in { + val cpg = code(""" + |package mypkg + |fun method(param: Array) {} + |""".stripMargin) + + val List(method) = cpg.method.name("method").l + method.signature shouldBe "void(java.lang.String[])" + } + + "test array type full name in method signature for nested array" in { + val cpg = code(""" + |package mypkg + |fun method(param: Array>) {} + |""".stripMargin) + + val List(method) = cpg.method.name("method").l + method.signature shouldBe "void(java.lang.String[][])" + } + + "test array type full name in method signature for builtin type array" in { + val cpg = code(""" + |package mypkg + |fun method(param: ByteArray) {} + |""".stripMargin) + + val List(method) = cpg.method.name("method").l + method.signature shouldBe "void(byte[])" + } + + "test array type full name in method signature for nested builtin type array" in { + val cpg = code(""" + |package mypkg + |fun method(param: Array) {} + |""".stripMargin) + + val List(method) = cpg.method.name("method").l + method.signature shouldBe "void(byte[][])" + } + + "test array type full name in method signature for array of kotlin type" in { + val cpg = code(""" + |package mypkg + |fun method(param: Array) {} + |""".stripMargin) + + val List(method) = cpg.method.name("method").l + method.signature shouldBe "void(boolean[])" + } +} diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AssignmentTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AssignmentTests.scala index 25ddaea2e3ef..4ef82a053da0 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AssignmentTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AssignmentTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class AssignmentTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/BooleanLogicTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/BooleanLogicTests.scala index 3e6ed13c3b82..922a64926e35 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/BooleanLogicTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/BooleanLogicTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class BooleanLogicTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallGraphTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallGraphTests.scala index 794edfec1c3e..ddd91f3f1f80 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallGraphTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallGraphTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CallGraphTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallTests.scala index a87bf5df6fc5..c81bd7a6626d 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallTests.scala @@ -1,7 +1,9 @@ package io.joern.kotlin2cpg.querying +import io.joern.kotlin2cpg.Constants import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier} +import io.joern.x2cpg.Defines +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import io.shiftleft.semanticcpg.language.* @@ -228,7 +230,7 @@ class CallTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a CALL node `writeText` with the correct props set" in { val List(c) = cpg.call.code("f.writeText.*").l - c.methodFullName shouldBe "java.io.File.writeText:void(java.lang.String,java.nio.charset.Charset)" + c.methodFullName shouldBe "kotlin.io.writeText:void(java.io.File,java.lang.String,java.nio.charset.Charset)" } } @@ -258,11 +260,11 @@ class CallTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a CALL node for `MyCaseClass(\\\"AN_ARGUMENT\\\")` with the correct props set" in { val List(c) = cpg.call.code("MyCaseClass.*AN_ARGUMENT.*").l - c.methodFullName shouldBe "no.such.CaseClass:ANY(ANY)" + c.methodFullName shouldBe s"no.such.CaseClass:${Defines.UnresolvedSignature}(1)" c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH c.lineNumber shouldBe Some(10) c.columnNumber shouldBe Some(17) - c.signature shouldBe "ANY(ANY)" + c.signature shouldBe s"${Defines.UnresolvedSignature}(1)" } } @@ -319,6 +321,38 @@ class CallTests extends KotlinCode2CpgFixture(withOssDataflow = false) { } } + "CPG for code with named arguments in call on object" should { + val cpg = code(""" + |package no.such.pkg + |fun outer() { + | Pair(1,2).copy(second = 3) + |} + |""".stripMargin) + + "contain a CALL node with arguments that have the argument name set" in { + val List(c) = cpg.call.name("copy").l + c.argument(1).argumentName shouldBe Some("second") + } + } + + "CPG for code with implicit this access on apply and run call" should { + val cpg = code(""" + |package no.such.pkg + | + |fun outer() { + | Pair(1,2).apply { println(second) } + |} + |""".stripMargin) + + "contain a CALL node with argument that is a this access" in { + val List(printCall) = cpg.call.name("println").l + val secondCall = printCall.argument(1).asInstanceOf[Call] + secondCall.methodFullName shouldBe Operators.fieldAccess + secondCall.code shouldBe "this.second" + secondCall.argument(1).asInstanceOf[Identifier].typeFullName shouldBe "kotlin.Pair" + } + } + "CPG for code with call with argument with type with upper bound" should { val cpg = code(""" |package mypkg @@ -356,7 +390,7 @@ class CallTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a METHOD node with correct METHOD_FULL_NAME set" in { val List(c) = cpg.method.nameExact("mapIndexedNotNullTo").callIn.l - c.methodFullName shouldBe "kotlin.sequences.Sequence.mapIndexedNotNullTo:java.lang.Object(java.util.Collection,kotlin.Function2)" + c.methodFullName shouldBe "kotlin.sequences.mapIndexedNotNullTo:java.util.Collection(kotlin.sequences.Sequence,java.util.Collection,kotlin.jvm.functions.Function2)" } } @@ -595,4 +629,108 @@ class CallTests extends KotlinCode2CpgFixture(withOssDataflow = false) { c.argument.map(_.argumentName).flatten.l shouldBe List("two", "one") } } + + "have correct call to overriden base class" in { + val cpg = code(""" + |package somePackage + |class A: java.io.Closeable { + | fun foo() { + | close() + | } + | override fun close() { + | } + |} + |""".stripMargin) + + inside(cpg.call.nameExact("close").l) { case List(call) => + call.methodFullName shouldBe "somePackage.A.close:void()" + call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + inside(call.receiver.l) { case List(receiver: Identifier) => + receiver.name shouldBe Constants.ThisName + receiver.typeFullName shouldBe "somePackage.A" + } + inside(call.argument.l) { case List(argument: Identifier) => + argument.name shouldBe Constants.ThisName + argument.argumentIndex shouldBe 0 + } + } + } + + "have correct call to kotlin standard library function" in { + val cpg = code(""" + |fun method() { + | println("test") + |} + |""".stripMargin) + + inside(cpg.call.nameExact("println").l) { case List(call) => + call.methodFullName shouldBe "kotlin.io.println:void(java.lang.Object)" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.receiver.isEmpty shouldBe true + inside(call.argument.l) { case List(argument: Literal) => + argument.code shouldBe "\"test\"" + argument.argumentIndex shouldBe 1 + } + } + } + + "have correct call to custom top level function" in { + val cpg = code(""" + |package somePackage + |fun topLevelFunc() { + |} + |fun method() { + | topLevelFunc() + |} + |""".stripMargin) + + inside(cpg.call.nameExact("topLevelFunc").l) { case List(call) => + call.methodFullName shouldBe "somePackage.topLevelFunc:void()" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + + call.receiver.isEmpty shouldBe true + call.argument.isEmpty shouldBe true + } + } + + "have correct call to private class method" in { + val cpg = code(""" + |package somePackage + |class A { + | private fun func1() { + | } + | fun func2() { + | func1() + | } + |} + |""".stripMargin) + + inside(cpg.call.nameExact("func1").l) { case List(call) => + call.methodFullName shouldBe "somePackage.A.func1:void()" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + + call.receiver.isEmpty shouldBe true + inside(call.argument.l) { case List(argument: Identifier) => + argument.name shouldBe "this" + argument.argumentIndex shouldBe 0 + } + } + } + + "have correct call for nested qualified expressions" in { + val cpg = code(""" + |package somePackage + |class A { + | private val sub: A?; + | fun func() { + | sub?.sub?.func(); + | } + |} + |""".stripMargin) + + inside(cpg.call.nameExact("func").l) { case List(call) => + call.methodFullName shouldBe "somePackage.A.func:void()" + call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + } + } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallableReferenceTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallableReferenceTests.scala index e5edd465e017..da0b64126080 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallableReferenceTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallableReferenceTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CallableReferenceTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallbackTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallbackTests.scala index d6c2814bafc0..9b1a3eb9c00d 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallbackTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallbackTests.scala @@ -2,7 +2,8 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, MethodRef} +import io.shiftleft.semanticcpg.language.* class CallbackTests extends KotlinCode2CpgFixture(withOssDataflow = false) { @@ -99,7 +100,12 @@ class CallbackTests extends KotlinCode2CpgFixture(withOssDataflow = false) { c.lineNumber shouldBe Some(10) c.columnNumber shouldBe Some(8) c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - c.argument.size shouldBe 1 + c.argument.size shouldBe 2 + inside(c.argument.l) { case List(arg1: Identifier, arg2: MethodRef) => + arg1.name shouldBe "this" + arg1.argumentIndex shouldBe 0 + arg2.argumentIndex shouldBe 1 + } } } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallsToConstructorTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallsToConstructorTests.scala index 49335321c273..e8f1d3cb550a 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallsToConstructorTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallsToConstructorTests.scala @@ -3,7 +3,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Identifier, Literal, Local} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CallsToConstructorTests extends KotlinCode2CpgFixture(withOssDataflow = false) { @@ -69,12 +69,12 @@ class CallsToConstructorTests extends KotlinCode2CpgFixture(withOssDataflow = fa val List(qeCall) = cpg.call.methodFullName(".*writeText.*").l val List(callLhs: Block, callRhs: Literal) = qeCall.argument.l: @unchecked - callRhs.argumentIndex shouldBe 1 + callRhs.argumentIndex shouldBe 2 val loweredBlock = callLhs loweredBlock.typeFullName shouldBe "java.io.File" loweredBlock.code shouldBe "" - loweredBlock.argumentIndex shouldBe 0 + loweredBlock.argumentIndex shouldBe 1 val List(firstBlockChild: Local) = loweredBlock.astChildren.take(1).l: @unchecked firstBlockChild.name shouldBe "tmp_1" diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallsToFieldAccessTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallsToFieldAccessTests.scala index 95c3ee8c49e3..62bf91de9621 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallsToFieldAccessTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CallsToFieldAccessTests.scala @@ -26,6 +26,7 @@ class CallsToFieldAccessTests extends KotlinCode2CpgFixture(withOssDataflow = fa val List(c) = cpg.call.codeExact("println(x)").argument.isCall.l c.code shouldBe "this.x" c.name shouldBe Operators.fieldAccess + c.typeFullName shouldBe "java.lang.String" c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH c.lineNumber shouldBe Some(6) c.columnNumber shouldBe Some(16) @@ -139,4 +140,21 @@ class CallsToFieldAccessTests extends KotlinCode2CpgFixture(withOssDataflow = fa x.methodFullName shouldBe "mypkg.AClass.printX:void()" } } + + "Field access after array/map access" should { + "have correct arguments" in { + val cpg = code(""" + |val m = LinkedHashMap() + |val x = m[1].aaa + |""".stripMargin) + + inside(cpg.call.methodFullNameExact(Operators.fieldAccess).argument.l) { case List(arg1, arg2) => + arg1.code shouldBe "m[1]" + arg1.argumentIndex shouldBe 1 + arg2.code shouldBe "aaa" + arg2.argumentIndex shouldBe 2 + } + } + } + } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CfgTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CfgTests.scala index 3ca374924968..8455d8aeb781 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CfgTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CfgTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CfgTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ClassLiteralTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ClassLiteralTests.scala index 0fe33a3b79ee..d0a01d0f68b9 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ClassLiteralTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ClassLiteralTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ClassLiteralTests extends KotlinCode2CpgFixture(withOssDataflow = false) { @@ -29,7 +29,7 @@ class ClassLiteralTests extends KotlinCode2CpgFixture(withOssDataflow = false) { c.lineNumber shouldBe Some(8) c.signature shouldBe "kotlin.reflect.KClass()" c.typeFullName shouldBe "kotlin.reflect.KClass" - c.methodFullName shouldBe "mypkg.Bar.getClass:kotlin.reflect.KClass()" + c.methodFullName shouldBe ".class" } "should contain a CALL node for the class literal expression inside dot-qualified expression" in { @@ -41,7 +41,7 @@ class ClassLiteralTests extends KotlinCode2CpgFixture(withOssDataflow = false) { c.lineNumber shouldBe Some(9) c.signature shouldBe "kotlin.reflect.KClass()" c.typeFullName shouldBe "kotlin.reflect.KClass" - c.methodFullName shouldBe "mypkg.Baz.getClass:kotlin.reflect.KClass()" + c.methodFullName shouldBe ".class" } } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CollectionAccessTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CollectionAccessTests.scala index 3d30027b02e5..aaddb05e6db2 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CollectionAccessTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CollectionAccessTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CollectionAccessTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CompanionObjectTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CompanionObjectTests.scala index bcb0948837db..09b52df6a5e6 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CompanionObjectTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/CompanionObjectTests.scala @@ -4,7 +4,7 @@ import io.joern.kotlin2cpg.Constants import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier, Member} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CompanionObjectTests extends KotlinCode2CpgFixture(withOssDataflow = false) { @@ -38,7 +38,7 @@ class CompanionObjectTests extends KotlinCode2CpgFixture(withOssDataflow = false val List(firstMember: Member, secondMember: Member) = td.member.l firstMember.name shouldBe "m" firstMember.typeFullName shouldBe "java.lang.String" - secondMember.name shouldBe Constants.companionObjectMemberName + secondMember.name shouldBe Constants.CompanionObjectMemberName secondMember.typeFullName shouldBe "mypkg.AClass" } @@ -64,7 +64,26 @@ class CompanionObjectTests extends KotlinCode2CpgFixture(withOssDataflow = false firstArg.argument.l: @unchecked firstArgOfLoweredCall.typeFullName shouldBe "mypkg.AClass$Companion" firstArgOfLoweredCall.refsTo.size shouldBe 0 // yes, 0. it's how the closed-source dataflow engine wants it atm - secondArgOfLoweredCall.canonicalName shouldBe Constants.companionObjectMemberName + secondArgOfLoweredCall.canonicalName shouldBe Constants.CompanionObjectMemberName + } + } + + "nested companion object and nested class test" in { + val cpg = code(""" + |package mypkg + | + |class AClass { + | companion object { + | class BClass { + | companion object NamedCompanion { + | } + | } + | } + |} + |""".stripMargin) + + inside(cpg.typeDecl.nameExact("NamedCompanion").l) { case List(typeDecl) => + typeDecl.fullName shouldBe "mypkg.AClass$Companion$BClass$NamedCompanion" } } @@ -98,7 +117,7 @@ class CompanionObjectTests extends KotlinCode2CpgFixture(withOssDataflow = false val List(firstMember: Member, secondMember: Member) = td.member.l firstMember.name shouldBe "m" firstMember.typeFullName shouldBe "java.lang.String" - secondMember.name shouldBe Constants.companionObjectMemberName + secondMember.name shouldBe Constants.CompanionObjectMemberName secondMember.typeFullName shouldBe "mypkg.AClass" } @@ -123,7 +142,7 @@ class CompanionObjectTests extends KotlinCode2CpgFixture(withOssDataflow = false firstArg.argument.l: @unchecked firstArgOfLoweredCall.typeFullName shouldBe "mypkg.AClass$NamedCompanion" firstArgOfLoweredCall.refsTo.size shouldBe 0 // yes, 0. it's how the closed-source dataflow engine wants it atm - secondArgOfLoweredCall.canonicalName shouldBe Constants.companionObjectMemberName + secondArgOfLoweredCall.canonicalName shouldBe Constants.CompanionObjectMemberName } } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ComparisonOperatorTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ComparisonOperatorTests.scala index 9057c973301e..6fc007434046 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ComparisonOperatorTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ComparisonOperatorTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ComparisonOperatorTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ComplexExpressionsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ComplexExpressionsTests.scala index 7e31ba806ead..b469e6f84ca1 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ComplexExpressionsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ComplexExpressionsTests.scala @@ -3,8 +3,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.edges.Argument -import io.shiftleft.semanticcpg.language._ -import overflowdb.traversal.jIteratortoTraversal +import io.shiftleft.semanticcpg.language.* class ComplexExpressionsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with _and_/_or_ operator and try-catch as one of the arguments" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ConfigFileTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ConfigFileTests.scala index e0fdc9238657..99fca7000164 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ConfigFileTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ConfigFileTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ConfigFileTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ConstructorTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ConstructorTests.scala index ecf3222d003b..4eb3f338274c 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ConstructorTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ConstructorTests.scala @@ -1,10 +1,10 @@ package io.joern.kotlin2cpg.querying -import io.joern.kotlin2cpg.Constants import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture +import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier, MethodParameterIn} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ConstructorTests extends KotlinCode2CpgFixture(withOssDataflow = false) { @@ -20,7 +20,7 @@ class ConstructorTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a METHOD node for the constructor with the correct props set" in { val List(m) = cpg.typeDecl.fullNameExact("mypkg.Foo").method.l m.fullName shouldBe "mypkg.Foo.:void()" - m.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName + m.name shouldBe Defines.ConstructorMethodName m.parameter.size shouldBe 1 Option(m.block).isDefined shouldBe true } @@ -36,7 +36,7 @@ class ConstructorTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a METHOD node for the constructor with a block with no children" in { val List(m) = cpg.typeDecl.fullNameExact("mypkg.AClass").method.l m.fullName shouldBe "mypkg.AClass.:void(java.lang.String)" - m.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName + m.name shouldBe Defines.ConstructorMethodName m.parameter.size shouldBe 2 Option(m.block).isDefined shouldBe true m.block.expressionDown.size shouldBe 0 @@ -64,7 +64,7 @@ class ConstructorTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a METHOD node for the constructor with the correct props set" in { val List(m) = cpg.typeDecl.fullNameExact("mypkg.AClass").method.l m.fullName shouldBe "mypkg.AClass.:void(java.lang.String)" - m.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName + m.name shouldBe Defines.ConstructorMethodName m.parameter.size shouldBe 2 Option(m.block).isDefined shouldBe true @@ -120,7 +120,7 @@ class ConstructorTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a METHOD node for the constructor with the correct props set" in { val List(m) = cpg.typeDecl.fullNameExact("mypkg.Foo").method.l m.fullName shouldBe "mypkg.Foo.:void(java.lang.String)" - m.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName + m.name shouldBe Defines.ConstructorMethodName m.parameter.size shouldBe 2 Option(m.block).isDefined shouldBe true } @@ -137,7 +137,7 @@ class ConstructorTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a METHOD node for the constructor with the correct props set" in { val List(m) = cpg.typeDecl.fullNameExact("mypkg.Foo").method.l m.fullName shouldBe "mypkg.Foo.:void(java.lang.String)" - m.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName + m.name shouldBe Defines.ConstructorMethodName m.parameter.size shouldBe 2 Option(m.block).isDefined shouldBe true } @@ -182,14 +182,14 @@ class ConstructorTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a METHOD node for the secondary constructor with properties set correctly" in { val List(m) = cpg.typeDecl.fullNameExact("mypkg.Foo").method.slice(1, 2).l m.fullName shouldBe "mypkg.Foo.:void(java.lang.String,int)" - m.name shouldBe io.joern.x2cpg.Defines.ConstructorMethodName + m.name shouldBe Defines.ConstructorMethodName m.lineNumber shouldBe Some(6) m.columnNumber shouldBe Some(4) m.methodReturn.typeFullName shouldBe "void" m.methodReturn.lineNumber shouldBe Some(6) m.methodReturn.columnNumber shouldBe Some(4) - m.block.astChildren.map(_.code).l shouldBe List(Constants.init, "this.bar = bar") + m.block.astChildren.map(_.code).l shouldBe List(Defines.ConstructorMethodName, "this.bar = bar") val List(mThisParam: MethodParameterIn, firstParam: MethodParameterIn, secondParam: MethodParameterIn) = m.parameter.l diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ControlStructureTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ControlStructureTests.scala index 4c14ced555f7..367c7e7f0fd3 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ControlStructureTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ControlStructureTests.scala @@ -3,7 +3,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, ControlStructure, Identifier, Local} import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ControlStructureTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with simple if-else" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DataClassTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DataClassTests.scala index b5bc553d1b4e..bee57af468ad 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DataClassTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DataClassTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DataClassTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with simple data class" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DefaultContentRootsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DefaultContentRootsTests.scala index c53f42ab2b96..ef8900fe0807 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DefaultContentRootsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DefaultContentRootsTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DefaultContentRootsTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDefaultJars = true) { @@ -240,7 +240,7 @@ class DefaultContentRootsTests extends KotlinCode2CpgFixture(withOssDataflow = f "should contain a CALL node for `routes` with the correct methodFullName set" in { val List(c) = cpg.call.methodFullName(".*routes.*").l - c.methodFullName shouldBe "org.http4k.routing.routes:org.http4k.routing.RoutingHttpHandler(kotlin.Array)" + c.methodFullName shouldBe "org.http4k.routing.routes:org.http4k.routing.RoutingHttpHandler(org.http4k.routing.RoutingHttpHandler[])" } "should contain a CALL node for `req.query` with the correct methodFullName set" in { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DelegatedPropertiesTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DelegatedPropertiesTests.scala index 7fc0b7dc4649..9f53a65e7f21 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DelegatedPropertiesTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DelegatedPropertiesTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DelegatedPropertiesTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DestructuringTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DestructuringTests.scala index c0d046f962ed..96d03e1a589c 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DestructuringTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/DestructuringTests.scala @@ -3,7 +3,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal, Local} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DestructuringTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/EnumTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/EnumTests.scala index 8feb1ba5df25..3d4e6ad56dea 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/EnumTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/EnumTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class EnumTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ExtensionTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ExtensionTests.scala index 7c1e535826ea..3e3d7736da63 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ExtensionTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ExtensionTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ExtensionTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with simple extension function declarations" should { @@ -11,25 +11,33 @@ class ExtensionTests extends KotlinCode2CpgFixture(withOssDataflow = false) { val cpg = code(""" |package mypkg | - |class Example { - | fun printBar() { println("class.bar") } - |} + |class Example {} | - |fun Example.printBaz() { println("ext.baz") } + |fun Example.printBaz(text: String) { println(text) } | |fun main(args : Array) { - | Example().printBaz() + | Example().printBaz("ext.baz") |} |""".stripMargin) "should contain a CALL node for the calls to the extension fns with the correct MFN set" in { val List(c) = cpg.call.code(".*printBaz.*").l - c.methodFullName shouldBe "mypkg.Example.printBaz:void()" + c.methodFullName shouldBe "mypkg.printBaz:void(mypkg.Example,java.lang.String)" } "should contain a METHOD node for the extension fn with the correct MFN set" in { val List(m) = cpg.method.fullName(".*printBaz.*").l - m.fullName shouldBe "mypkg.Example.printBaz:void()" + m.fullName shouldBe "mypkg.printBaz:void(mypkg.Example,java.lang.String)" + } + + "should contain a METHOD node for the extension fn with the correct parameter indicies" in { + val x = cpg.method.fullName.l + inside(cpg.method.fullName(".*printBaz.*").parameter.l) { case List(thisParam, textParam) => + thisParam.index shouldBe 1 + thisParam.order shouldBe 1 + textParam.index shouldBe 2 + textParam.order shouldBe 2 + } } } @@ -73,23 +81,23 @@ class ExtensionTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a CALL node with the correct props set for the call to the package-level defined extension fn" in { val List(c) = cpg.call.code("str.hash.*").where(_.method.fullName(".*main.*")).l - c.methodFullName shouldBe "java.lang.String.hash:java.lang.String()" + c.methodFullName shouldBe "mypkg.hash:java.lang.String(java.lang.String)" c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - c.signature shouldBe "java.lang.String()" + c.signature shouldBe "java.lang.String(java.lang.String)" } "should contain a CALL node with the correct props set for the call to the extension fn defined in `AClass`" in { val List(c) = cpg.typeDecl.fullName(".*AClass.*").method.fullName(".*hashStr.*").call.code("str.hash.*").l - c.methodFullName shouldBe "java.lang.String.hash:java.lang.String()" + c.methodFullName shouldBe "mypkg.AClass.hash:java.lang.String(java.lang.String)" c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - c.signature shouldBe "java.lang.String()" + c.signature shouldBe "java.lang.String(java.lang.String)" } "should contain a CALL node with the correct props set for the call to the extension fn defined in `BClass`" in { val List(c) = cpg.typeDecl.fullName(".*BClass.*").method.fullName(".*hashStr.*").call.code("str.hash.*").l - c.methodFullName shouldBe "java.lang.String.hash:java.lang.String()" + c.methodFullName shouldBe "mypkg.BClass.hash:java.lang.String(java.lang.String)" c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - c.signature shouldBe "java.lang.String()" + c.signature shouldBe "java.lang.String(java.lang.String)" } } @@ -97,19 +105,15 @@ class ExtensionTests extends KotlinCode2CpgFixture(withOssDataflow = false) { val cpg = code(""" |package mypkg |fun f1(p: String) { - | val cs: CharSequence = "abcd" - | cs.onEach { println(it) } + | val cs: String = "abcd" + | cs.onEach { } |} |""".stripMargin) implicit val resolver = NoResolve "contain a CALL node with the correct METHOD_FULLNAME set" in { val List(c) = cpg.method.nameExact("onEach").callIn.l - // from the documentation at https://kotlinlang.org/api/latest/jvm/stdlib/kotlin.text/on-each.html - // ``` - // inline fun S.onEach(action: (Char) -> Unit): S - // ``` - c.methodFullName shouldBe "java.lang.CharSequence.onEach:java.lang.Object(kotlin.Function1)" + c.methodFullName shouldBe "kotlin.text.onEach:java.lang.CharSequence(java.lang.CharSequence,kotlin.jvm.functions.Function1)" } } @@ -128,5 +132,4 @@ class ExtensionTests extends KotlinCode2CpgFixture(withOssDataflow = false) { p1.typeFullName shouldBe "mypkg.AClass" } } - } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/FieldAccessTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/FieldAccessTests.scala index fd8521787396..df05efd93966 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/FieldAccessTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/FieldAccessTests.scala @@ -3,7 +3,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.{FieldIdentifier, Identifier} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class FieldAccessTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/FileTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/FileTests.scala index 098c7f9329ba..4dae7d002d57 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/FileTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/FileTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import java.io.File diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/GenericsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/GenericsTests.scala index d0a78b7ec837..55100727bac4 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/GenericsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/GenericsTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class GenericsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/GlobalsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/GlobalsTests.scala index ff11e7d51c86..8e3eed81a92f 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/GlobalsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/GlobalsTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class GlobalsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code simple global declaration" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/IdentifierTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/IdentifierTests.scala index 0650814cb889..6c4c0745afa4 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/IdentifierTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/IdentifierTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class IdentifierTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with two simple methods" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ImportTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ImportTests.scala index 00b79d34b1ea..38bdb3aba291 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ImportTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ImportTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ImportTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/InnerClassesTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/InnerClassesTests.scala index 7bc06f15d41e..b81c62b036ca 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/InnerClassesTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/InnerClassesTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.Ignore class InnerClassesTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LabeledExpressionsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LabeledExpressionsTests.scala index b4243faf88c0..dc0fb0b566f9 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LabeledExpressionsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LabeledExpressionsTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LabeledExpressionsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LambdaTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LambdaTests.scala index 6de2838c017e..1c54c040b460 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LambdaTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LambdaTests.scala @@ -17,7 +17,6 @@ import io.shiftleft.codepropertygraph.generated.nodes.MethodRef import io.shiftleft.codepropertygraph.generated.nodes.Return import io.shiftleft.codepropertygraph.generated.nodes.TypeDecl import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.jIteratortoTraversal class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDefaultJars = true) { "CPG for code with a simple lambda which captures a method parameter" should { @@ -42,9 +41,11 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef cb._refOut.size shouldBe 1 } - "should contain a CALL node with the signature of the lambda" in { + "should contain a CALL node for the `let` invocation" in { val List(c) = cpg.call.code("1.let.*").l - c.signature shouldBe "java.lang.Object(java.lang.Object)" + c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + c.methodFullName shouldBe "kotlin.let:java.lang.Object(java.lang.Object,kotlin.jvm.functions.Function1)" + c.signature shouldBe "java.lang.Object(java.lang.Object,kotlin.jvm.functions.Function1)" } } @@ -92,8 +93,8 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a METHOD node for the lambda the correct props set" in { val List(m) = cpg.method.fullName(".*lambda.*0.*").l - m.fullName shouldBe "mypkg..0:java.lang.Object(java.lang.Object)" - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.fullName shouldBe s"mypkg.f2.${Defines.ClosurePrefix}0:java.lang.String(java.lang.String)" + m.signature shouldBe "java.lang.String(java.lang.String)" } } @@ -118,8 +119,8 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a METHOD node for the lambda the correct props set" in { val List(m) = cpg.method.fullName(".*lambda.*").l - m.fullName shouldBe "mypkg..0:java.lang.Object(java.lang.Object)" - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.fullName shouldBe s"mypkg.foo.${Defines.ClosurePrefix}0:void(java.lang.String)" + m.signature shouldBe "void(java.lang.String)" m.lineNumber shouldBe Some(6) m.columnNumber shouldBe Some(14) } @@ -127,7 +128,7 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a METHOD node for the lambda with a corresponding METHOD_RETURN which has the correct props set" in { val List(mr) = cpg.method.fullName(".*lambda.*").methodReturn.l mr.evaluationStrategy shouldBe EvaluationStrategies.BY_VALUE - mr.typeFullName shouldBe "java.lang.Object" + mr.typeFullName shouldBe "void" mr.lineNumber shouldBe Some(6) mr.columnNumber shouldBe Some(14) } @@ -148,15 +149,15 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a CALL node for `forEach` with the correct properties set" in { val List(c) = cpg.call.methodFullName(".*forEach.*").l - c.methodFullName shouldBe "java.lang.Iterable.forEach:void(kotlin.Function1)" - c.signature shouldBe "void(java.lang.Object)" + c.methodFullName shouldBe "kotlin.collections.forEach:void(java.lang.Iterable,kotlin.jvm.functions.Function1)" + c.signature shouldBe "void(java.lang.Iterable,kotlin.jvm.functions.Function1)" c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH c.lineNumber shouldBe Some(6) c.columnNumber shouldBe Some(4) val List(firstArg, secondArg) = cpg.call.methodFullName(".*forEach.*").argument.l - firstArg.argumentIndex shouldBe 0 - secondArg.argumentIndex shouldBe 1 + firstArg.argumentIndex shouldBe 1 + secondArg.argumentIndex shouldBe 2 } "should contain a TYPE_DECL node for the lambda with the correct props set" in { @@ -166,13 +167,15 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef td.inheritsFromTypeFullName shouldBe Seq("kotlin.Function1") Option(td.astParent).isDefined shouldBe true - val List(bm) = cpg.typeDecl.fullName(".*lambda.*").boundMethod.l - bm.fullName shouldBe "mypkg..0:java.lang.Object(java.lang.Object)" + val List(bm) = cpg.typeDecl.fullName(".*lambda.*").boundMethod.dedup.l + bm.fullName shouldBe s"mypkg.foo.${Defines.ClosurePrefix}0:void(java.lang.String)" bm.name shouldBe s"${Defines.ClosurePrefix}0" - val List(b) = bm.refIn.collect { case r: Binding => r }.l - b.signature shouldBe "java.lang.Object(java.lang.Object)" - b.name shouldBe Constants.lambdaBindingName + val List(b1, b2) = bm.referencingBinding.l + b1.signature shouldBe "void(java.lang.String)" + b1.name shouldBe "invoke" + b2.signature shouldBe "java.lang.Object(java.lang.Object)" + b2.name shouldBe "invoke" } "should contain a METHOD_PARAMETER_IN for the lambda with referencing identifiers" in { @@ -183,46 +186,96 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "CPG for code with a scope function lambda with implicit parameter" should { val cpg = code(""" |package mypkg + | |fun f3(p: String): String { - | val out = p.apply { println(this) } - | return out + | val out = p.apply { println(this) } + | return out |} - ||""".stripMargin) + |""".stripMargin) "should contain a METHOD_PARAMETER_IN for the lambda with the correct properties set" in { val List(p) = cpg.method.fullName(".*lambda.*").parameter.l p.code shouldBe "this" - p.typeFullName shouldBe "ANY" + p.typeFullName shouldBe "java.lang.String" p.index shouldBe 1 } } + "lambda should contain METHOD_PARAMETER_IN for both implicit lambda parameters" in { + val cpg = code(""" + |package mypkg + | + |public fun myFunc(block: String.(Int) -> Unit): Unit {} + | fun outer(param: String): Unit { + | myFunc { println(it); println(this) + | } + |} + |""".stripMargin) + + val List(thisParam, itParam) = cpg.method.fullName(".*lambda.*").parameter.l + thisParam.code shouldBe "this" + thisParam.typeFullName shouldBe "java.lang.String" + thisParam.index shouldBe 1 + itParam.code shouldBe "it" + itParam.typeFullName shouldBe "int" + itParam.index shouldBe 2 + } + "CPG for code containing a lambda with parameter destructuring" should { - val cpg = code("""|package mypkg + val cpg = code(""" + |package mypkg | |fun f1(p: String) { - | val m = mapOf(p to 1, "two" to 2, "three" to 3) - | m.forEach { (k, v) -> - | println(k) - | } + | val m = mapOf(p to 1, "two" to 2, "three" to 3) + | m.forEach { (k, v) -> + | println(k) + | } |} |""".stripMargin) "should contain a METHOD node for the lambda the correct props set" in { val List(m) = cpg.method.fullName(".*lambda.*").l - m.fullName shouldBe "mypkg..0:java.lang.Object(java.lang.Object)" - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.fullName shouldBe s"mypkg.f1.${Defines.ClosurePrefix}0:void(java.util.Map$$Entry)" + m.signature shouldBe "void(java.util.Map$Entry)" } "should contain METHOD_PARAMETER_IN nodes for the lambda with the correct properties set" in { - val List(p1, p2) = cpg.method.fullName(".*lambda.*").parameter.l - p1.code shouldBe "k" + val List(p1) = cpg.method.fullName(".*lambda.*").parameter.l + p1.code shouldBe s"${Constants.DestructedParamNamePrefix}1" p1.index shouldBe 1 - p1.typeFullName shouldBe "java.lang.String" - p2.code shouldBe "v" - p2.index shouldBe 2 - p2.typeFullName shouldBe "int" + p1.typeFullName shouldBe "java.util.Map$Entry" + } + + "should contain the correct initialization" in { + val List(_, _, localTmp, localIt, localK, localV) = cpg.method.fullName(".*lambda.*").local.l + localTmp.name shouldBe "tmp_1" + localTmp.typeFullName shouldBe "java.util.Map$Entry" + localIt.name shouldBe "it" + localIt.typeFullName shouldBe "java.util.Map$Entry" + localK.name shouldBe "k" + localK.typeFullName shouldBe "java.lang.String" + localV.name shouldBe "v" + localV.typeFullName shouldBe "int" + + val List(tmpAssignment, kAssignment, vAssignment) = cpg.method.fullName(".*lambda.*").ast.isCall.isAssignment.l + tmpAssignment.code shouldBe "tmp_1 = it" + val List(tmp, it) = tmpAssignment.astChildren.isIdentifier.l + tmp.typeFullName shouldBe "java.util.Map$Entry" + it.typeFullName shouldBe "java.util.Map$Entry" + + kAssignment.code shouldBe "k = tmp_1.component1()" + val List(k) = kAssignment.astChildren.isIdentifier.l + k.typeFullName shouldBe "java.lang.String" + + vAssignment.code shouldBe "v = tmp_1.component2()" + val List(v) = vAssignment.astChildren.isIdentifier.l + v.typeFullName shouldBe "int" + + cpg.identifier.filter(_._astIn.isEmpty) shouldBe empty + cpg.identifier.filter(_.refsTo.isEmpty) shouldBe empty + cpg.local.filter(_._astIn.isEmpty) shouldBe empty } + } "CPG for code containing a lambda with parameter destructuring and an `_` entry" should { @@ -239,15 +292,39 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a METHOD node for the lambda the correct props set" in { val List(m) = cpg.method.fullName(".*lambda.*").l - m.fullName shouldBe "mypkg..0:java.lang.Object(java.lang.Object)" - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.fullName shouldBe s"mypkg.f1.${Defines.ClosurePrefix}0:void(java.util.Map$$Entry)" + m.signature shouldBe "void(java.util.Map$Entry)" } "should contain one METHOD_PARAMETER_IN node for the lambda with the correct properties set" in { val List(p1) = cpg.method.fullName(".*lambda.*").parameter.l - p1.code shouldBe "k" + p1.code shouldBe s"${Constants.DestructedParamNamePrefix}1" p1.index shouldBe 1 - p1.typeFullName shouldBe "java.lang.String" + p1.typeFullName shouldBe "java.util.Map$Entry" + } + + "should contain the correct initialization" in { + val List(_, _, localTmp, localIt, localK) = cpg.method.fullName(".*lambda.*").local.l + localTmp.name shouldBe "tmp_1" + localTmp.typeFullName shouldBe "java.util.Map$Entry" + localIt.name shouldBe "it" + localIt.typeFullName shouldBe "java.util.Map$Entry" + localK.name shouldBe "k" + localK.typeFullName shouldBe "java.lang.String" + + val List(tmpAssignment, kAssignment) = cpg.method.fullName(".*lambda.*").ast.isCall.isAssignment.l + tmpAssignment.code shouldBe "tmp_1 = it" + val List(tmp, it) = tmpAssignment.astChildren.isIdentifier.l + tmp.typeFullName shouldBe "java.util.Map$Entry" + it.typeFullName shouldBe "java.util.Map$Entry" + + kAssignment.code shouldBe "k = tmp_1.component1()" + val List(k) = kAssignment.astChildren.isIdentifier.l + k.typeFullName shouldBe "java.lang.String" + + cpg.identifier.filter(_._astIn.isEmpty) shouldBe empty + cpg.identifier.filter(_.refsTo.isEmpty) shouldBe empty + cpg.local.filter(_._astIn.isEmpty) shouldBe empty } } @@ -262,14 +339,14 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a METHOD node for the lambda the correct props set" in { val List(m) = cpg.method.fullName(".*lambda.*").l - m.fullName shouldBe "mypkg..0:java.lang.Object(java.lang.Object)" - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.fullName shouldBe s"mypkg.throughTakeIf.${Defines.ClosurePrefix}0:boolean(java.lang.String)" + m.signature shouldBe "boolean(java.lang.String)" } "should contain a METHOD node for the lambda with a corresponding METHOD_RETURN which has the correct props set" in { val List(mr) = cpg.method.fullName(".*lambda.*").methodReturn.l mr.evaluationStrategy shouldBe EvaluationStrategies.BY_VALUE - mr.typeFullName shouldBe "java.lang.Object" + mr.typeFullName shouldBe "boolean" } "should contain a METHOD node for the lambda with a corresponding MODIFIER which has the correct props set" in { @@ -287,10 +364,10 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a CALL node for `takeIf` with the correct properties set" in { val List(c) = cpg.call.code("x.takeIf.*").l - c.methodFullName shouldBe "java.lang.Object.takeIf:java.lang.Object(kotlin.Function1)" + c.methodFullName shouldBe "kotlin.takeIf:java.lang.Object(java.lang.Object,kotlin.jvm.functions.Function1)" c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH c.typeFullName shouldBe "java.lang.String" - c.signature shouldBe "java.lang.Object(java.lang.Object)" + c.signature shouldBe "java.lang.Object(java.lang.Object,kotlin.jvm.functions.Function1)" } "should contain a RETURN node around as the last child of the lambda's BLOCK" in { @@ -308,13 +385,15 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef td.code shouldBe "LAMBDA_TYPE_DECL" Option(td.astParent).isDefined shouldBe true - val List(bm) = cpg.typeDecl.fullName(".*lambda.*").boundMethod.l - bm.fullName shouldBe "mypkg..0:java.lang.Object(java.lang.Object)" + val List(bm) = cpg.typeDecl.fullName(".*lambda.*").boundMethod.dedup.l + bm.fullName shouldBe s"mypkg.throughTakeIf.${Defines.ClosurePrefix}0:boolean(java.lang.String)" bm.name shouldBe s"${Defines.ClosurePrefix}0" - val List(b) = bm.refIn.collect { case r: Binding => r }.l - b.signature shouldBe "java.lang.Object(java.lang.Object)" - b.name shouldBe Constants.lambdaBindingName + val List(b1, b2) = bm.referencingBinding.l + b1.signature shouldBe "boolean(java.lang.String)" + b1.name shouldBe "invoke" + b2.signature shouldBe "java.lang.Object(java.lang.Object)" + b2.name shouldBe "invoke" } "should contain a METHOD_PARAMETER_IN for the lambda with referencing identifiers" in { @@ -337,8 +416,8 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a METHOD node for the lambda the correct props set" in { val List(m) = cpg.method.fullName(".*lambda.*0.*").l - m.fullName shouldBe "mypkg..0:java.lang.Object(java.lang.Object)" - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.fullName shouldBe s"mypkg.mappedListWith.${Defines.ClosurePrefix}0:java.lang.String(java.lang.String)" + m.signature shouldBe "java.lang.String(java.lang.String)" m.lineNumber shouldBe Some(6) m.columnNumber shouldBe Some(28) } @@ -346,7 +425,7 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a METHOD node for the lambda with a corresponding METHOD_RETURN which has the correct props set" in { val List(mr) = cpg.method.fullName(".*lambda.*").methodReturn.l mr.evaluationStrategy shouldBe EvaluationStrategies.BY_VALUE - mr.typeFullName shouldBe "java.lang.Object" + mr.typeFullName shouldBe "java.lang.String" mr.lineNumber shouldBe Some(6) mr.columnNumber shouldBe Some(28) } @@ -365,7 +444,7 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a CALL node for `map` with the correct properties set" in { val List(c) = cpg.call.methodFullName(".*map.*").take(1).l - c.methodFullName shouldBe "java.lang.Iterable.map:java.util.List(kotlin.Function1)" + c.methodFullName shouldBe "kotlin.collections.map:java.util.List(java.lang.Iterable,kotlin.jvm.functions.Function1)" c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH c.lineNumber shouldBe Some(6) c.columnNumber shouldBe Some(20) @@ -376,13 +455,15 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef td.isExternal shouldBe false td.code shouldBe "LAMBDA_TYPE_DECL" - val List(bm) = cpg.typeDecl.fullName(".*lambda.*").boundMethod.l - bm.fullName shouldBe "mypkg..0:java.lang.Object(java.lang.Object)" + val List(bm) = cpg.typeDecl.fullName(".*lambda.*").boundMethod.dedup.l + bm.fullName shouldBe s"mypkg.mappedListWith.${Defines.ClosurePrefix}0:java.lang.String(java.lang.String)" bm.name shouldBe s"${Defines.ClosurePrefix}0" - val List(b) = bm.refIn.collect { case r: Binding => r }.l - b.signature shouldBe "java.lang.Object(java.lang.Object)" - b.name shouldBe Constants.lambdaBindingName + val List(b1, b2) = bm.referencingBinding.l + b1.signature shouldBe "java.lang.String(java.lang.String)" + b1.name shouldBe "invoke" + b2.signature shouldBe "java.lang.Object(java.lang.Object)" + b2.name shouldBe "invoke" } "should contain a METHOD_PARAMETER_IN for the lambda with referencing identifiers" in { @@ -425,7 +506,7 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a METHOD node for the lambda with the correct props set" in { val List(m) = cpg.method.fullName(".*lambda.*").l - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.signature shouldBe "void(int)" } } @@ -449,7 +530,7 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "should contain a METHOD node for the second lambda with the correct props set" in { val List(m) = cpg.method.fullName(".*lambda.*1.*").l - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.signature shouldBe "void(int)" } "should contain METHOD_REF nodes with the correct props set" in { @@ -496,29 +577,6 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef } } - "CPG for code with nested lambdas" should { - val cpg = code(""" - |package mypkg - | - |fun doSomething(p: String): Int { - | 1.let { - | 2.let { - | println(p) - | } - | } - | return 0 - |} - |""".stripMargin) - - "should contain a single LOCAL node inside the BLOCK of the first lambda" in { - cpg.method.fullName(".*lambda.*0.*").block.astChildren.isLocal.size shouldBe 1 - } - - "should contain two LOCAL nodes inside the BLOCK of the second lambda" in { - cpg.method.fullName(".*lambda.*1.*").block.astChildren.isLocal.size shouldBe 2 - } - } - "CPG for code with lambda with no statements in its block" should { val cpg = code(""" |package mypkg @@ -542,13 +600,16 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "contain a METHOD node for the lambda with the correct signature" in { val List(m) = cpg.method.fullName(".*lambda.*").l - m.signature shouldBe "java.lang.Object()" + m.signature shouldBe "java.lang.String()" } "contain a BINDING node for the lambda with the correct signature" in { - val List(b1, b2) = cpg.typeDecl.methodBinding.l - b1.signature shouldBe "void(java.lang.String)" + val List(m) = cpg.method.fullName(".*lambda.*").l + val List(b1, b2) = m.referencingBinding.l + b1.signature shouldBe "java.lang.String()" + b1.name shouldBe "invoke" b2.signature shouldBe "java.lang.Object()" + b2.name shouldBe "invoke" } } @@ -563,10 +624,157 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef "contain a METHOD node for the lambda with a PARAMETER with implicit parameter name" in { val List(m) = cpg.method.fullName(".*lambda.*").l - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.signature shouldBe "void(java.lang.String)" val List(p) = m.parameter.l p.name shouldBe "it" + p.index shouldBe 1 } } + "test nested lambda full names" in { + val cpg = code(""" + |package mypkg + |val x = { i: Int -> + | val y = { j: Int -> + | j + | } + |} + |""".stripMargin) + val List(m1, m2) = cpg.method.fullName(".*lambda..*").l + m1.fullName shouldBe s"mypkg.x.${Defines.ClosurePrefix}0:void(int)" + m2.fullName shouldBe s"mypkg.x.${Defines.ClosurePrefix}0.${Defines.ClosurePrefix}1:int(int)" + } + + "CPG for code with lambda directly used as argument for interface parameter" should { + val cpg = code(""" + |package mypkg + |open class AAA + |class BBB: AAA() + |fun interface SomeInterface { + | fun method(param: T): T + |} + |fun interfaceUser(someInterface: SomeInterface) {} + |fun invoke() { + | interfaceUser { obj -> obj } + |} + |""".stripMargin) + + "contain correct lambda, bindings and type decl nodes" in { + val List(lambdaMethod) = cpg.method.fullName(".*lambda.*").l + lambdaMethod.fullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0:mypkg.BBB(mypkg.BBB)" + lambdaMethod.signature shouldBe "mypkg.BBB(mypkg.BBB)" + + val List(lambdaTypeDecl) = lambdaMethod.bindingTypeDecl.dedup.l + lambdaTypeDecl.fullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0" + lambdaTypeDecl.inheritsFromTypeFullName should contain theSameElementsAs (List("mypkg.SomeInterface")) + + val List(binding1, binding2) = lambdaMethod.referencingBinding.l + binding1.name shouldBe "method" + binding1.signature shouldBe "mypkg.BBB(mypkg.BBB)" + binding1.methodFullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0:mypkg.BBB(mypkg.BBB)" + binding1.bindingTypeDecl shouldBe lambdaTypeDecl + binding2.name shouldBe "method" + binding2.signature shouldBe "mypkg.AAA(mypkg.AAA)" + binding2.methodFullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0:mypkg.BBB(mypkg.BBB)" + binding2.bindingTypeDecl shouldBe lambdaTypeDecl + } + } + + "CPG for code with wrapped lambda used as argument for interface parameter" should { + val cpg = code(""" + |package mypkg + |open class AAA + |class BBB: AAA() + |fun interface SomeInterface { + | fun method(param: T): T + |} + |fun interfaceUser(someInterface: SomeInterface) {} + |fun invoke() { + | interfaceUser(SomeInterface{ obj -> obj }) + |} + |""".stripMargin) + + "contain correct lambda, bindings and type decl nodes" in { + val List(lambdaMethod) = cpg.method.fullName(".*lambda.*").l + lambdaMethod.fullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0:mypkg.BBB(mypkg.BBB)" + lambdaMethod.signature shouldBe "mypkg.BBB(mypkg.BBB)" + + val List(lambdaTypeDecl) = lambdaMethod.bindingTypeDecl.dedup.l + lambdaTypeDecl.fullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0" + lambdaTypeDecl.inheritsFromTypeFullName should contain theSameElementsAs (List("mypkg.SomeInterface")) + + val List(binding1, binding2) = lambdaMethod.referencingBinding.l + binding1.name shouldBe "method" + binding1.signature shouldBe "mypkg.BBB(mypkg.BBB)" + binding1.methodFullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0:mypkg.BBB(mypkg.BBB)" + binding1.bindingTypeDecl shouldBe lambdaTypeDecl + binding2.name shouldBe "method" + binding2.signature shouldBe "mypkg.AAA(mypkg.AAA)" + binding2.methodFullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0:mypkg.BBB(mypkg.BBB)" + binding2.bindingTypeDecl shouldBe lambdaTypeDecl + } + } + + "CPG for code with wrapped lambda assigned to local variable" should { + val cpg = code(""" + |package mypkg + |open class AAA + |class BBB: AAA() + |fun interface SomeInterface { + | fun method(param: T): T + |} + |fun invoke() { + | val aaa: SomeInterface = SomeInterface{ obj -> obj } + |} + |""".stripMargin) + + "contain correct lambda, bindings and type decl nodes" in { + val List(lambdaMethod) = cpg.method.fullName(".*lambda.*").l + lambdaMethod.fullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0:mypkg.BBB(mypkg.BBB)" + lambdaMethod.signature shouldBe "mypkg.BBB(mypkg.BBB)" + + val List(lambdaTypeDecl) = lambdaMethod.bindingTypeDecl.dedup.l + lambdaTypeDecl.fullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0" + lambdaTypeDecl.inheritsFromTypeFullName should contain theSameElementsAs (List("mypkg.SomeInterface")) + + val List(binding1, binding2) = lambdaMethod.referencingBinding.l + binding1.name shouldBe "method" + binding1.signature shouldBe "mypkg.BBB(mypkg.BBB)" + binding1.methodFullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0:mypkg.BBB(mypkg.BBB)" + binding1.bindingTypeDecl shouldBe lambdaTypeDecl + binding2.name shouldBe "method" + binding2.signature shouldBe "mypkg.AAA(mypkg.AAA)" + binding2.methodFullName shouldBe s"mypkg.invoke.${Defines.ClosurePrefix}0:mypkg.BBB(mypkg.BBB)" + binding2.bindingTypeDecl shouldBe lambdaTypeDecl + } + } + + "CPG for code with lambda wrapped in label" should { + val cpg = code(""" + |package mypkg + |fun outer() { + | listOf(1).forEach someLabel@{x: Int -> x} + |} + |""".stripMargin) + + "contain correct lambda, bindings and type decl nodes" in { + val List(lambdaMethod) = cpg.method.fullName(".*lambda.*").l + lambdaMethod.fullName shouldBe s"mypkg.outer.${Defines.ClosurePrefix}0:void(int)" + lambdaMethod.signature shouldBe "void(int)" + + val List(lambdaTypeDecl) = lambdaMethod.bindingTypeDecl.dedup.l + lambdaTypeDecl.fullName shouldBe s"mypkg.outer.${Defines.ClosurePrefix}0" + lambdaTypeDecl.inheritsFromTypeFullName should contain theSameElementsAs (List("kotlin.Function1")) + + val List(binding1, binding2) = lambdaMethod.referencingBinding.l + binding1.name shouldBe "invoke" + binding1.signature shouldBe "void(int)" + binding1.methodFullName shouldBe s"mypkg.outer.${Defines.ClosurePrefix}0:void(int)" + binding1.bindingTypeDecl shouldBe lambdaTypeDecl + binding2.name shouldBe "invoke" + binding2.signature shouldBe "java.lang.Object(java.lang.Object)" + binding2.methodFullName shouldBe s"mypkg.outer.${Defines.ClosurePrefix}0:void(int)" + binding2.bindingTypeDecl shouldBe lambdaTypeDecl + } + } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LiteralTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LiteralTests.scala index ba4288736d87..5a5c8475318a 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LiteralTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LiteralTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LiteralTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LocalClassesTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LocalClassesTests.scala index bb57eb21932b..42a261fe5f83 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LocalClassesTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LocalClassesTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.TypeDecl -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LocalClassesTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LocalTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LocalTests.scala index a0c49d8842ca..0dc9da45c66e 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LocalTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/LocalTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LocalTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code simple local declarations" should { @@ -25,4 +25,20 @@ class LocalTests extends KotlinCode2CpgFixture(withOssDataflow = false) { l2.typeFullName shouldBe "int" } } + + "CPG for local declaration without initialization" should { + val cpg = code(""" + |fun main() { + | var x: Int + |} + |""".stripMargin) + + "contain LOCAL node for `x`" in { + val List(l1) = cpg.local("x").l + l1.code shouldBe "x" + l1.name shouldBe "x" + l1.typeFullName shouldBe "int" + } + } + } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MemberTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MemberTests.scala index 19457372c82d..4d0db375a81f 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MemberTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MemberTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MemberTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MetaDataTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MetaDataTests.scala index 40695e336c24..1f54baae7bc2 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MetaDataTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MetaDataTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MetaDataTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodParameterTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodParameterTests.scala index 5e5db29df2dc..3a610e47ab1a 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodParameterTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodParameterTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.MethodParameterIn -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MethodParameterTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodReturnTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodReturnTests.scala index d34372b3d2d9..fc332dade842 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodReturnTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodReturnTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.EvaluationStrategies -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MethodReturnTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodTests.scala index efc0e016c1e9..5acf980d988c 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/MethodTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Return} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MethodTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with simple method defined at package-level" should { @@ -31,9 +31,9 @@ class MethodTests extends KotlinCode2CpgFixture(withOssDataflow = false) { x.filename.endsWith(".kt") shouldBe true val List(y) = cpg.method.name("main").isExternal(false).l - y.fullName shouldBe "main:void(kotlin.Array)" + y.fullName shouldBe "main:void(java.lang.String[])" y.code shouldBe "main" - y.signature shouldBe "void(kotlin.Array)" + y.signature shouldBe "void(java.lang.String[])" y.isExternal shouldBe false y.lineNumber shouldBe Some(6) x.columnNumber shouldBe Some(4) @@ -174,7 +174,7 @@ class MethodTests extends KotlinCode2CpgFixture(withOssDataflow = false) { |""".stripMargin) "pass the lambda to a `sortedWith` call which is then under the method `sorted`" in { - inside(cpg.methodRef(".*.*").inCall.l) { + inside(cpg.methodRefWithName(".*.*").inCall.l) { case sortedWith :: Nil => sortedWith.name shouldBe "sortedWith" sortedWith.method.name shouldBe "sorted" @@ -182,4 +182,15 @@ class MethodTests extends KotlinCode2CpgFixture(withOssDataflow = false) { } } } + + "test correct translation of parameter kotlin type to java type" in { + val cpg = code(""" + |fun method(x: kotlin.CharArray) { + |} + |""".stripMargin) + + inside(cpg.method.name("method").l) { case List(method) => + method.fullName shouldBe "method:void(char[])" + } + } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ModifierTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ModifierTests.scala index 3b686142e41b..605185f136b9 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ModifierTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ModifierTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ModifierTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with various modifiers applied to various functions" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/NamespaceBlockTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/NamespaceBlockTests.scala index 33cf5c4be855..200ff2a9a41b 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/NamespaceBlockTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/NamespaceBlockTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class NamespaceBlockTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with simple namespace declaration" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ObjectDeclarationsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ObjectDeclarationsTests.scala index 81118e41fe2a..310aa4086717 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ObjectDeclarationsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ObjectDeclarationsTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ObjectDeclarationsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with simple object declaration" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ObjectExpressionTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ObjectExpressionTests.scala index 044ea675d94a..a2b0e4e1d971 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ObjectExpressionTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ObjectExpressionTests.scala @@ -1,6 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture +import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.nodes.Call import io.shiftleft.codepropertygraph.generated.nodes.Identifier import io.shiftleft.codepropertygraph.generated.nodes.Local @@ -36,38 +37,38 @@ class ObjectExpressionTests extends KotlinCode2CpgFixture(withOssDataflow = fals "should contain TYPE_DECL nodes with the correct props set" in { val List(p1, p2, foo, o1, o2) = cpg.typeDecl.isExternal(false).nameNot("").l - p1.fullName shouldBe "mypkg.foo$object$1.printWithSuffix:void(java.lang.String)" - p2.fullName shouldBe "mypkg.foo$object$2.printWithSuffix:void(java.lang.String)" + p1.fullName shouldBe "mypkg.foo.object$0.printWithSuffix:void(java.lang.String)" + p2.fullName shouldBe "mypkg.foo.object$1.printWithSuffix:void(java.lang.String)" foo.fullName shouldBe "mypkg.foo:void(java.lang.String)" - o1.fullName shouldBe "mypkg.foo$object$1" - o2.fullName shouldBe "mypkg.foo$object$2" + o1.fullName shouldBe "mypkg.foo.object$0" + o2.fullName shouldBe "mypkg.foo.object$1" } "should contain two LOCAL nodes with the correct props set" in { val List(firstL: Local, secondL: Local) = cpg.local.l - firstL.typeFullName shouldBe "mypkg.foo$object$1" - secondL.typeFullName shouldBe "mypkg.foo$object$2" + firstL.typeFullName shouldBe "mypkg.foo.object$0" + secondL.typeFullName shouldBe "mypkg.foo.object$1" } "should contain two correctly-lowered representations of the assignments" in { val List(firstAssignment: Call, secondAssignment: Call) = cpg.method.nameExact("foo").call.methodFullNameExact(".assignment").l val List(firstAssignmentLHS: Identifier, firstAssignmentRHS: Call) = firstAssignment.argument.l: @unchecked - firstAssignmentLHS.typeFullName shouldBe "mypkg.foo$object$1" + firstAssignmentLHS.typeFullName shouldBe "mypkg.foo.object$0" firstAssignmentRHS.methodFullName shouldBe ".alloc" val List(secondAssignmentLHS: Identifier, secondAssignmentRHS: Call) = secondAssignment.argument.l: @unchecked - secondAssignmentLHS.typeFullName shouldBe "mypkg.foo$object$2" + secondAssignmentLHS.typeFullName shouldBe "mypkg.foo.object$1" secondAssignmentRHS.methodFullName shouldBe ".alloc" } "should contain two correctly-lowered calls to methods of the anonymous objects" in { val List(firstCall: Call, secondCall: Call) = cpg.call.methodFullName(".*printWithSuffix.*").l - firstCall.methodFullName shouldBe "mypkg.foo$object$1.printWithSuffix:void(java.lang.String)" - secondCall.methodFullName shouldBe "mypkg.foo$object$2.printWithSuffix:void(java.lang.String)" + firstCall.methodFullName shouldBe "mypkg.foo.object$0.printWithSuffix:void(java.lang.String)" + secondCall.methodFullName shouldBe "mypkg.foo.object$1.printWithSuffix:void(java.lang.String)" val List(firstCallLHS: Identifier, _: Identifier) = firstCall.argument.l: @unchecked - firstCallLHS.typeFullName shouldBe "mypkg.foo$object$1" + firstCallLHS.typeFullName shouldBe "mypkg.foo.object$0" val List(secondCallLHS: Identifier, _: Identifier) = secondCall.argument.l: @unchecked - secondCallLHS.typeFullName shouldBe "mypkg.foo$object$2" + secondCallLHS.typeFullName shouldBe "mypkg.foo.object$1" } } @@ -92,35 +93,35 @@ class ObjectExpressionTests extends KotlinCode2CpgFixture(withOssDataflow = fals val List(f1, does, f2, foo, interface, obj) = cpg.typeDecl.isExternal(false).nameNot("").l f1.fullName shouldBe "mypkg.AnInterface.doSomething:void(java.lang.String)" does.fullName shouldBe "mypkg.does:void(mypkg.AnInterface,java.lang.String)" - f2.fullName shouldBe "mypkg.foo$object$1.doSomething:void(java.lang.String)" + f2.fullName shouldBe "mypkg.foo.object$0.doSomething:void(java.lang.String)" foo.fullName shouldBe "mypkg.foo:void(java.lang.String)" interface.fullName shouldBe "mypkg.AnInterface" interface.inheritsFromTypeFullName shouldBe List("java.lang.Object") obj.name shouldBe "anonymous_obj" - obj.fullName shouldBe "mypkg.foo$object$1" + obj.fullName shouldBe "mypkg.foo.object$0" obj.inheritsFromTypeFullName shouldBe Seq("mypkg.AnInterface") val List(firstMethod: Method, secondMethod: Method) = obj.boundMethod.l - firstMethod.fullName shouldBe "mypkg.foo$object$1.doSomething:void(java.lang.String)" - secondMethod.fullName shouldBe "mypkg.foo$object$1.:void()" + firstMethod.fullName shouldBe "mypkg.foo.object$0.doSomething:void(java.lang.String)" + secondMethod.fullName shouldBe "mypkg.foo.object$0.:void()" } "contain a LOCAL node with the correct props set" in { val List(l: Local) = cpg.local.l l.name shouldBe "tmp_obj_1" - l.typeFullName shouldBe "mypkg.foo$object$1" + l.typeFullName shouldBe "mypkg.foo.object$0" } "contain a CALL node assigning a temp identifier to an alloc call" in { val List(firstAssignment: Call) = cpg.call.methodFullNameExact(".assignment").l val List(firstAssignmentLHS: Identifier, firstAssignmentRHS: Call) = firstAssignment.argument.l: @unchecked - firstAssignmentLHS.typeFullName shouldBe "mypkg.foo$object$1" + firstAssignmentLHS.typeFullName shouldBe "mypkg.foo.object$0" firstAssignmentRHS.methodFullName shouldBe ".alloc" } "contain a CALL node for an on the temp identifier" in { val List(c: Call) = cpg.call.nameExact("").l - c.methodFullName shouldBe "mypkg.foo$object$1.:void()" + c.methodFullName shouldBe "mypkg.foo.object$0.:void()" } } @@ -140,12 +141,12 @@ class ObjectExpressionTests extends KotlinCode2CpgFixture(withOssDataflow = fals "contain TYPE_DECL nodes with the correct props set" in { val List(interfaceF1, objectF1, f1, interface, qClass, obj) = cpg.typeDecl.isExternal(false).nameNot("").l interfaceF1.fullName shouldBe "mypkg.SomeInterface.doSomething:void()" - objectF1.fullName shouldBe "mypkg.f1$object$1.doSomething:void()" + objectF1.fullName shouldBe "mypkg.f1.object$0.doSomething:void()" f1.fullName shouldBe "mypkg.f1:void()" interface.fullName shouldBe "mypkg.SomeInterface" qClass.fullName shouldBe "mypkg.QClass" obj.name shouldBe "anonymous_obj" - obj.fullName shouldBe "mypkg.f1$object$1" + obj.fullName shouldBe "mypkg.f1.object$0" obj.inheritsFromTypeFullName shouldBe Seq("mypkg.SomeInterface") } } @@ -180,7 +181,7 @@ class ObjectExpressionTests extends KotlinCode2CpgFixture(withOssDataflow = fals c.methodFullName shouldBe "mypkg.PClass.addListener:void(mypkg.SomeInterface)" val List(objExpr: TypeDecl, l: Local, alloc: Call, init: Call, i: Identifier) = c.astChildren.isBlock.astChildren.l: @unchecked - objExpr.fullName shouldBe "mypkg.withFailListener$object$1" + objExpr.fullName shouldBe "mypkg.withFailListener.object$0" l.code shouldBe "tmp_obj_1" alloc.code shouldBe "tmp_obj_1 = " init.code shouldBe "" @@ -209,7 +210,7 @@ class ObjectExpressionTests extends KotlinCode2CpgFixture(withOssDataflow = fals c.methodFullName shouldBe "mypkg.addListener:void(mypkg.SomeInterface)" val List(objExpr: TypeDecl, l: Local, alloc: Call, init: Call, i: Identifier) = c.astChildren.isBlock.astChildren.l: @unchecked - objExpr.fullName shouldBe "mypkg.f1.$object$1" + objExpr.fullName shouldBe s"mypkg.f1.${Defines.ClosurePrefix}0.object$$1" l.code shouldBe "tmp_obj_1" alloc.code shouldBe "tmp_obj_1 = " init.code shouldBe "" @@ -231,7 +232,7 @@ class ObjectExpressionTests extends KotlinCode2CpgFixture(withOssDataflow = fals | """.stripMargin) "contain a correctly lowered representation" in { - cpg.typeDecl.fullNameExact("mypkg.AN_OBJ$object$1").l should not be List() + cpg.typeDecl.fullNameExact("mypkg.AN_OBJ.object$0").l should not be List() } } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ParenthesizedExpressionTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ParenthesizedExpressionTests.scala index da0f5858545c..bd1548c9e909 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ParenthesizedExpressionTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ParenthesizedExpressionTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ParenthesizedExpressionTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/QualifiedExpressionsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/QualifiedExpressionsTests.scala index 824698ecce4a..570ec15821b2 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/QualifiedExpressionsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/QualifiedExpressionsTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class QualifiedExpressionsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with qualified expression with QE as a receiver" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ResolutionErrorsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ResolutionErrorsTests.scala index c4e676596e97..e8ff82618993 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ResolutionErrorsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ResolutionErrorsTests.scala @@ -4,7 +4,7 @@ import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.joern.kotlin2cpg.types.TypeConstants import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ResolutionErrorsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with QE of receiver for which the type cannot be inferred" should { @@ -22,7 +22,7 @@ class ResolutionErrorsTests extends KotlinCode2CpgFixture(withOssDataflow = fals "should contain a CALL node with an MFN starting with a placeholder type" in { val List(c) = cpg.call.slice(1, 2).l - c.methodFullName shouldBe Defines.UnresolvedNamespace + ".flatMap:ANY(ANY)" + c.methodFullName shouldBe s"${Defines.UnresolvedNamespace}.flatMap:${Defines.UnresolvedSignature}(1)" } } @@ -108,12 +108,12 @@ class ResolutionErrorsTests extends KotlinCode2CpgFixture(withOssDataflow = fals "should contain a CALL node with the correct MFN set when type info is available" in { val List(c) = cpg.call.methodFullName(Operators.assignment).where(_.argument(1).code("foo")).argument(2).isCall.l - c.methodFullName shouldBe "java.lang.Iterable.filter:java.util.List(kotlin.Function1)" + c.methodFullName shouldBe "kotlin.collections.filter:java.util.List(java.lang.Iterable,kotlin.jvm.functions.Function1)" } "should contain a CALL node with the correct MFN set when type info is not available" in { val List(c) = cpg.call.methodFullName(Operators.assignment).where(_.argument(1).code("bar")).argument(2).isCall.l - c.methodFullName shouldBe Defines.UnresolvedNamespace + ".filter:ANY(ANY)" + c.methodFullName shouldBe s"${Defines.UnresolvedNamespace}.filter:${Defines.UnresolvedSignature}(1)" } } @@ -134,7 +134,7 @@ class ResolutionErrorsTests extends KotlinCode2CpgFixture(withOssDataflow = fals "should contain a METHOD node with a MFN property starting with `kotlin.Any`" in { val List(m) = cpg.method.fullName(".*getFileSize.*").l - m.fullName shouldBe s"${Defines.UnresolvedNamespace}.getFileSize:int(boolean)" + m.fullName shouldBe s"mypkg.getFileSize:${Defines.UnresolvedSignature}(1)" } } @@ -156,7 +156,7 @@ class ResolutionErrorsTests extends KotlinCode2CpgFixture(withOssDataflow = fals "should contain a METHOD node with a MFN property that replaced the unresolvable types with `kotlin.Any`" in { val List(m) = cpg.method.fullName(".*clone.*").take(1).l - m.fullName shouldBe "java.util.Map.clone:java.util.Map()" + m.fullName shouldBe "mypkg.clone:java.util.Map(java.util.Map)" } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SafeQualifiedExpressionsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SafeQualifiedExpressionsTests.scala index 28267aa5ef34..f340e39dd835 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SafeQualifiedExpressionsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SafeQualifiedExpressionsTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class SafeQualifiedExpressionsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ScopeFunctionTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ScopeFunctionTests.scala index 8d5a75d91e27..8285760d897f 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ScopeFunctionTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ScopeFunctionTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.{Block, Return} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ScopeFunctionTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code call to `also` scope function without an explicitly-defined parameter" should { @@ -13,13 +13,13 @@ class ScopeFunctionTests extends KotlinCode2CpgFixture(withOssDataflow = false) p.name shouldBe "it" } - "should NOT contain a RETURN node around as the last child of the lambda's BLOCK" in { + "should contain a RETURN node around as the last child of the lambda's BLOCK" in { val List(b: Block) = cpg.method.fullName(".*lambda.*").block.l val hasReturnAsLastChild = b.astChildren.last match { case _: Return => true case _ => false } - hasReturnAsLastChild shouldBe false + hasReturnAsLastChild shouldBe true } } @@ -31,13 +31,13 @@ class ScopeFunctionTests extends KotlinCode2CpgFixture(withOssDataflow = false) p.name shouldBe "this" } - "should NOT contain a RETURN node around as the last child of the lambda's BLOCK" in { + "should contain a RETURN node around as the last child of the lambda's BLOCK" in { val List(b: Block) = cpg.method.fullName(".*lambda.*").block.l val hasReturnAsLastChild = b.astChildren.last match { case _: Return => true case _ => false } - hasReturnAsLastChild shouldBe false + hasReturnAsLastChild shouldBe true } } @@ -106,7 +106,7 @@ class ScopeFunctionTests extends KotlinCode2CpgFixture(withOssDataflow = false) "should contain a METHOD node with the correct signature" in { val List(m) = cpg.method.fullName(".*lambda.*").l - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.signature shouldBe "void(int)" } } @@ -140,7 +140,7 @@ class ScopeFunctionTests extends KotlinCode2CpgFixture(withOssDataflow = false) "should contain a METHOD node with the correct signature" in { val List(m) = cpg.method.fullName(".*lambda.*").l - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.signature shouldBe "void(int)" } } @@ -174,7 +174,7 @@ class ScopeFunctionTests extends KotlinCode2CpgFixture(withOssDataflow = false) "should contain a METHOD node with the correct signature" in { val List(m) = cpg.method.fullName(".*lambda.*").l - m.signature shouldBe "java.lang.Object(java.lang.Object)" + m.signature shouldBe "void(int)" } } @@ -212,7 +212,7 @@ class ScopeFunctionTests extends KotlinCode2CpgFixture(withOssDataflow = false) "should X" in { val List(c) = cpg.call.code("x.takeIf.*").l - c.methodFullName shouldBe "java.lang.Object.takeIf:java.lang.Object(kotlin.Function1)" + c.methodFullName shouldBe "kotlin.takeIf:java.lang.Object(java.lang.Object,kotlin.jvm.functions.Function1)" } } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SpecialOperatorsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SpecialOperatorsTests.scala index a307d4b01dda..78078f2cf511 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SpecialOperatorsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SpecialOperatorsTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class SpecialOperatorsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/StdLibTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/StdLibTests.scala index 72498ccfea49..94f5e363daff 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/StdLibTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/StdLibTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class StdLibTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with call to `takeIf`" should { @@ -27,9 +27,9 @@ class StdLibTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a CALL node with the correct METHOD_FULL_NAME for `takeIf`" in { val List(c) = cpg.call.code("x.takeIf.*").l - c.methodFullName shouldBe "java.lang.Object.takeIf:java.lang.Object(kotlin.Function1)" + c.methodFullName shouldBe "kotlin.takeIf:java.lang.Object(java.lang.Object,kotlin.jvm.functions.Function1)" c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - c.signature shouldBe "java.lang.Object(java.lang.Object)" + c.signature shouldBe "java.lang.Object(java.lang.Object,kotlin.jvm.functions.Function1)" c.typeFullName shouldBe "java.util.UUID" } } @@ -114,23 +114,18 @@ class StdLibTests extends KotlinCode2CpgFixture(withOssDataflow = false) { |package mypkg | |fun foo() { - | val numbersMap = mapOf("key1" to 1, "key2" to 2, "key3" to 3, "key4" to 1) + | val numbersMap = mapOf("key1" to 1, "key2" to 2) | println(numbersMap) |} |""".stripMargin) "should contain CALL nodes for calls to infix fn `to`" in { val List(c1) = cpg.call.code("\"key1.*").l - c1.methodFullName shouldBe "kotlin.to:kotlin.Pair(java.lang.Object)" + c1.methodFullName shouldBe "kotlin.to:kotlin.Pair(java.lang.Object,java.lang.Object)" + c1.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH val List(c2) = cpg.call.code("\"key2.*").l - c2.methodFullName shouldBe "kotlin.to:kotlin.Pair(java.lang.Object)" - - val List(c3) = cpg.call.code("\"key3.*").l - c3.methodFullName shouldBe "kotlin.to:kotlin.Pair(java.lang.Object)" - - val List(c4) = cpg.call.code("\"key4.*").l - c4.methodFullName shouldBe "kotlin.to:kotlin.Pair(java.lang.Object)" + c2.methodFullName shouldBe "kotlin.to:kotlin.Pair(java.lang.Object,java.lang.Object)" } "CPG for code with calls to stdlib's `split`s" should { @@ -147,8 +142,10 @@ class StdLibTests extends KotlinCode2CpgFixture(withOssDataflow = false) { |""".stripMargin) "should contain CALL nodes for `split` with the correct MFNs set" in { - cpg.call.methodFullName(".*split.*").methodFullName.toSet shouldBe - Set("java.lang.CharSequence.split:java.util.List(kotlin.Array,boolean,int)") + inside(cpg.call.methodFullName(".*split.*").l) { case List(call1, call2) => + call1.methodFullName shouldBe "kotlin.text.split:java.util.List(java.lang.CharSequence,java.lang.String[],boolean,int)" + call2.methodFullName shouldBe call1.methodFullName + } } } @@ -170,8 +167,8 @@ class StdLibTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a CALL node for `trim` with the correct props set" in { val List(c) = cpg.call.code("p.trim.*").l - c.methodFullName shouldBe "java.lang.String.trim:java.lang.String()" - c.signature shouldBe "java.lang.String()" + c.methodFullName shouldBe "kotlin.text.trim:java.lang.String(java.lang.String)" + c.signature shouldBe "java.lang.String(java.lang.String)" c.typeFullName shouldBe "java.lang.String" c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH c.lineNumber shouldBe Some(5) @@ -180,7 +177,7 @@ class StdLibTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "should contain a CALL node for `trim` a receiver arg with the correct props set" in { val List(receiverArg) = cpg.call.code("p.trim.*").argument.isIdentifier.l - receiverArg.argumentIndex shouldBe 0 + receiverArg.argumentIndex shouldBe 1 receiverArg.name shouldBe "p" receiverArg.code shouldBe "p" receiverArg.typeFullName shouldBe "java.lang.String" diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/StringInterpolationTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/StringInterpolationTests.scala index 644435182656..7253fbbf3221 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/StringInterpolationTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/StringInterpolationTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class StringInterpolationTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SuperTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SuperTests.scala index 79de86e710b2..00d832760d29 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SuperTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/SuperTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class SuperTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with simple call using _super_" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ThisTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ThisTests.scala index 8af6d9af6548..d976133041d6 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ThisTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ThisTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.Identifier -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ThisTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with calls to functions of same name, but different scope" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TryExpressionsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TryExpressionsTests.scala index ce5b2fc7ce02..85761c2c8a6b 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TryExpressionsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TryExpressionsTests.scala @@ -3,7 +3,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Identifier} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class TryExpressionsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with simple `try`-expression" should { @@ -49,7 +49,7 @@ class TryExpressionsTests extends KotlinCode2CpgFixture(withOssDataflow = false) firstAstChildOfSecondArg.order shouldBe 1 firstAstChildOfSecondArg.name shouldBe "toInt" firstAstChildOfSecondArg.code shouldBe "r.toInt()" - firstAstChildOfSecondArg.methodFullName shouldBe "java.lang.String.toInt:int()" + firstAstChildOfSecondArg.methodFullName shouldBe "kotlin.text.toInt:int(java.lang.String)" } } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeAliasTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeAliasTests.scala index d76b92786d78..3af134f92b59 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeAliasTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeAliasTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.joern.x2cpg.Defines -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class TypeAliasTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDefaultJars = true) { "CPG for code with simple typealias to Int" should { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeDeclTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeDeclTests.scala index e05da10d7311..3ee7dc51b82d 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeDeclTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeDeclTests.scala @@ -11,7 +11,7 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ MethodParameterIn } import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal class TypeDeclTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeTests.scala index e529c84539ba..21c49b78a9d2 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/TypeTests.scala @@ -66,48 +66,4 @@ class TypeTests extends KotlinCode2CpgFixture(withOssDataflow = false) { x.name shouldBe "l" } } - - "generics with 'keep type arguments' config" should { - - "show the fully qualified type arguments for stdlib `List and `Map` objects" in { - val cpg = code(""" - |import java.util.ArrayList - |import java.util.HashMap - | - |fun foo() { - | val stringList = ArrayList() - | val stringIntMap = HashMap() - |} - |""".stripMargin) - .withConfig(Config().withKeepTypeArguments(true)) - - cpg.identifier("stringList").typeFullName.head shouldBe "java.util.ArrayList" - cpg.identifier("stringIntMap").typeFullName.head shouldBe "java.util.HashMap" - } - - "show the fully qualified names of external types" in { - val cpg = code(""" - |import org.apache.flink.streaming.api.datastream.DataStream - |import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment - |import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer - |import org.apache.flink.streaming.util.serialization.SimpleStringSchema - | - |import java.util.Properties; - | - |class FlinkKafkaExample { - | fun main() { - | val kafkaProducer = FlinkKafkaProducer("kafka-topic") - | } - |} - |""".stripMargin).withConfig(Config().withKeepTypeArguments(true)) - - cpg.call - .codeExact("FlinkKafkaProducer(\"kafka-topic\")") - .filterNot(_.name == Operators.alloc) - .map(_.methodFullName) - .head shouldBe "org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer:ANY(ANY)" - } - - } - } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/UnaryOpTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/UnaryOpTests.scala index 7c6aa3617877..4ca47bfc65aa 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/UnaryOpTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/UnaryOpTests.scala @@ -2,7 +2,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class UnaryOpTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/testfixtures/KotlinCodeToCpgFixture.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/testfixtures/KotlinCodeToCpgFixture.scala index 486654812a20..713bcfd868dc 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/testfixtures/KotlinCodeToCpgFixture.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/testfixtures/KotlinCodeToCpgFixture.scala @@ -1,8 +1,9 @@ package io.joern.kotlin2cpg.testfixtures import better.files.File as BFile +import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.language.* -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.joern.dataflowengineoss.testfixtures.SemanticCpgTestFixture import io.joern.dataflowengineoss.testfixtures.SemanticTestCpg import io.joern.kotlin2cpg.Config @@ -57,14 +58,14 @@ class KotlinCode2CpgFixture( withOssDataflow: Boolean = false, withDefaultJars: Boolean = false, withPostProcessing: Boolean = false, - extraFlows: List[FlowSemantic] = List.empty + semantics: Semantics = DefaultSemantics() ) extends Code2CpgFixture(() => new KotlinTestCpg(withDefaultJars) .withOssDataflow(withOssDataflow) - .withExtraFlows(extraFlows) + .withSemantics(semantics) .withPostProcessingPasses(withPostProcessing) ) - with SemanticCpgTestFixture(extraFlows) { + with SemanticCpgTestFixture(semantics) { protected def flowToResultPairs(path: Path): List[(String, Option[Int])] = path.resultPairs() } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/DefaultRegisteredTypesTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/types/DefaultRegisteredTypesTests.scala similarity index 85% rename from joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/DefaultRegisteredTypesTests.scala rename to joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/types/DefaultRegisteredTypesTests.scala index 1978c0e3d840..6219f73c175a 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/DefaultRegisteredTypesTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/types/DefaultRegisteredTypesTests.scala @@ -1,7 +1,7 @@ -package io.joern.kotlin2cpg +package io.joern.kotlin2cpg.types import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DefaultRegisteredTypesTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/types/KotlinScriptFilteringTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/types/KotlinScriptFilteringTests.scala deleted file mode 100644 index 7996bbc123a6..000000000000 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/types/KotlinScriptFilteringTests.scala +++ /dev/null @@ -1,35 +0,0 @@ -package io.joern.kotlin2cpg.types - -import io.joern.kotlin2cpg.compiler.{CompilerAPI, ErrorLoggingMessageCollector} -import org.jetbrains.kotlin.resolve.BindingContext -import org.scalatest.freespec.AnyFreeSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.Ignore - -@Ignore // uncomment as soon as the sourceDir path is correct -class KotlinScriptFilteringTests extends AnyFreeSpec with Matchers { - "Running `CompilerAPI.makeEnvironment` on external project with lots of KotlinScript sources" - { - "should return an empty binding context" in { - val sourceDir = "src/test/resources/external_projects/kotlin-dsl" - val environment = - CompilerAPI.makeEnvironment(Seq(sourceDir), Seq(), Seq(), new ErrorLoggingMessageCollector) - environment.getSourceFiles should not be List() - - val nameGenerator = new DefaultTypeInfoProvider(environment) - nameGenerator.bindingContext should not be null - nameGenerator.bindingContext shouldBe BindingContext.EMPTY - } - - "should not return an empty binding context" in { - val sourceDir = "src/test/resources/external_projects/kotlin-dsl" - val dirsForSourcesToCompile = ContentSourcesPicker.dirsForRoot(sourceDir) - val environment = - CompilerAPI.makeEnvironment(dirsForSourcesToCompile, Seq(), Seq(), new ErrorLoggingMessageCollector) - environment.getSourceFiles should not be List() - - val nameGenerator = new DefaultTypeInfoProvider(environment) - nameGenerator.bindingContext should not be null - nameGenerator.bindingContext should not be BindingContext.EMPTY - } - } -} diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/DefaultImportsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/DefaultImportsTests.scala index 2e381352d259..607313386ce5 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/DefaultImportsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/DefaultImportsTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.validation import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class DefaultImportsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { // It tests if we take into consideration default imports: https://kotlinlang.org/docs/packages.html#default-imports diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/IdentifierReferencesTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/IdentifierReferencesTests.scala index c24e30d6d6b7..02e94913b2ae 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/IdentifierReferencesTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/IdentifierReferencesTests.scala @@ -3,7 +3,7 @@ package io.joern.kotlin2cpg.validation import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Local, MethodParameterIn} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* // TODO: also add test with refs inside TYPE_DECL diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/MissingTypeInformationTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/MissingTypeInformationTests.scala index c435499d3310..36cc8397b270 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/MissingTypeInformationTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/MissingTypeInformationTests.scala @@ -3,7 +3,7 @@ package io.joern.kotlin2cpg.validation import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.nodes.{Call, Method} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class MissingTypeInformationTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with CALL to Java stdlib fn with argument of unknown type" should { @@ -59,10 +59,9 @@ class MissingTypeInformationTests extends KotlinCode2CpgFixture(withOssDataflow |""".stripMargin) "contain METHODs node for the constructors with the METHOD_FULL_NAMEs set" in { - val List(m1: Method, m2: Method, m3: Method) = cpg.method.nameExact("").l + val List(m1: Method, m2: Method) = cpg.method.nameExact("").l m1.fullName shouldBe "com.insecureshop.CartAdapter.:void()" - m2.fullName shouldBe "com.insecureshop.CartAdapter.CartViewHolder.:void(com.insecureshop.databinding.CartItemBinding)" - m3.fullName shouldBe "com.insecureshop.CartAdapter.CartViewHolder.:void(ANY)" + m2.fullName shouldBe s"com.insecureshop.CartAdapter$$CartViewHolder.:${Defines.UnresolvedSignature}(1)" } } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/PrimitiveArrayTypeMappingTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/PrimitiveArrayTypeMappingTests.scala index afb57d324c9f..aeffe8b32052 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/PrimitiveArrayTypeMappingTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/PrimitiveArrayTypeMappingTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.validation import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class PrimitiveArrayTypeMappingTests extends KotlinCode2CpgFixture(withOssDataflow = false) { "CPG for code with usage of `kotlin.BooleanArray`" should { @@ -53,7 +53,7 @@ class PrimitiveArrayTypeMappingTests extends KotlinCode2CpgFixture(withOssDatafl "should contain a CALL node with a METHOD_FULL_NAME starting with `kotlin.ByteArray`" in { val List(c) = cpg.call.code("byte.*toString.*").l - c.methodFullName shouldBe "kotlin.ByteArray.toString:java.lang.String(java.nio.charset.Charset)" + c.methodFullName shouldBe "kotlin.collections.toString:java.lang.String(byte[],java.nio.charset.Charset)" } } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/UnitTypeMappingTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/UnitTypeMappingTests.scala index dc2f0f467378..8ee9dde0d2e1 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/UnitTypeMappingTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/UnitTypeMappingTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.validation import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class UnitTypeMappingTests extends KotlinCode2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/ValidationTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/ValidationTests.scala index 2f8d27ee235b..a36e8957f025 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/ValidationTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/validation/ValidationTests.scala @@ -723,6 +723,7 @@ class ValidationTests extends KotlinCode2CpgFixture(withOssDataflow = false) { .fullNameNot(".*.*") .fullNameNot(".*") .fullNameNot(".*") + .fullNameNot(".*.*") .fullName(".*>.*") .fullName .l shouldBe List() @@ -734,6 +735,7 @@ class ValidationTests extends KotlinCode2CpgFixture(withOssDataflow = false) { .methodFullNameNot(".*.*") .methodFullNameNot(".*") .methodFullNameNot(".*") + .methodFullNameNot(".*.*") .methodFullName(".*>.*") .methodFullName .l shouldBe List() diff --git a/joern-cli/frontends/php2cpg/build.sbt b/joern-cli/frontends/php2cpg/build.sbt index 203ef91a60e1..c95e07becf50 100644 --- a/joern-cli/frontends/php2cpg/build.sbt +++ b/joern-cli/frontends/php2cpg/build.sbt @@ -4,20 +4,19 @@ import better.files.File name := "php2cpg" -val phpParserVersion = "4.15.8" val upstreamParserBinName = "php-parser.phar" -val versionedParserBinName = s"php-parser-$phpParserVersion.phar" +val versionedParserBinName = s"php-parser-${Versions.phpParser}.phar" val phpParserDlUrl = - s"https://github.com/joernio/PHP-Parser/releases/download/v$phpParserVersion/$upstreamParserBinName" + s"https://github.com/joernio/PHP-Parser/releases/download/v${Versions.phpParser}/$upstreamParserBinName" dependsOn(Projects.dataflowengineoss % "compile->compile;test->test", Projects.x2cpg % "compile->compile;test->test") libraryDependencies ++= Seq( - "com.lihaoyi" %% "upickle" % Versions.upickle, - "com.lihaoyi" %% "ujson" % Versions.upickle, - "io.shiftleft" %% "codepropertygraph" % Versions.cpg, + "com.lihaoyi" %% "upickle" % Versions.upickle, + "com.lihaoyi" %% "ujson" % Versions.upickle, + "io.shiftleft" %% "codepropertygraph" % Versions.cpg, "com.github.sh4869" %% "semver-parser-scala" % Versions.semverParser, - "org.scalatest" %% "scalatest" % Versions.scalatest % Test + "org.scalatest" %% "scalatest" % Versions.scalatest % Test ) lazy val phpParseInstallTask = taskKey[Unit]("Install PHP-Parse using PHP Composer") @@ -37,3 +36,7 @@ Compile / compile := ((Compile / compile) dependsOn phpParseInstallTask).value enablePlugins(JavaAppPackaging, LauncherJarPlugin) Global / onChangedBuildSource := ReloadOnSourceChanges + +/** write the php parser version to the manifest for downstream usage */ +Compile / packageBin / packageOptions += + Package.ManifestAttributes(new java.util.jar.Attributes.Name("PHP-Parser-Version") -> Versions.phpParser) diff --git a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/Main.scala b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/Main.scala index db3ade3308b1..afe2947167a6 100644 --- a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/Main.scala +++ b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/Main.scala @@ -3,6 +3,7 @@ package io.joern.php2cpg import io.joern.php2cpg.Frontend.* import io.joern.x2cpg.passes.frontend.* import io.joern.x2cpg.{DependencyDownloadConfig, X2CpgConfig, X2CpgMain} +import io.joern.x2cpg.utils.server.FrontendHTTPServer import scopt.OParser /** Command line configuration parameters @@ -51,8 +52,12 @@ object Frontend { } } -object Main extends X2CpgMain(cmdLineParser, new Php2Cpg()) { +object Main extends X2CpgMain(cmdLineParser, new Php2Cpg()) with FrontendHTTPServer[Config, Php2Cpg] { + + override protected def newDefaultConfig(): Config = Config() + def run(config: Config, php2Cpg: Php2Cpg): Unit = { - php2Cpg.run(config) + if (config.serverMode) { startup() } + else { php2Cpg.run(config) } } } diff --git a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/Php2Cpg.scala b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/Php2Cpg.scala index d0230fd616d5..c2b0b759be25 100644 --- a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/Php2Cpg.scala +++ b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/Php2Cpg.scala @@ -22,7 +22,7 @@ class Php2Cpg extends X2CpgFrontend[Config] { private val PhpVersionRegex = new Regex("^PHP ([78]\\.[1-9]\\.[0-9]|[9-9]\\d\\.\\d\\.\\d)") private def isPhpVersionSupported: Boolean = { - val result = ExternalCommand.run("php --version", ".") + val result = ExternalCommand.run(Seq("php", "--version"), ".").toTry result match { case Success(listString) => val phpVersionStr = listString.headOption.getOrElse("") diff --git a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/astcreation/AstCreator.scala index c821740b5c9c..c5ad8d6a9cb1 100644 --- a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/astcreation/AstCreator.scala @@ -8,30 +8,37 @@ import io.joern.php2cpg.utils.Scope import io.joern.x2cpg.Ast.storeInDiffGraph import io.joern.x2cpg.Defines.{StaticInitMethodName, UnresolvedNamespace, UnresolvedSignature} import io.joern.x2cpg.utils.AstPropertiesUtil.RootProperties +import io.joern.x2cpg.utils.IntervalKeyPool import io.joern.x2cpg.utils.NodeBuilders.* -import io.joern.x2cpg.{Ast, AstCreatorBase, AstNodeBuilder, ValidationMode} +import io.joern.x2cpg.{Ast, AstCreatorBase, AstNodeBuilder, Defines, ValidationMode} import io.shiftleft.codepropertygraph.generated.* import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.passes.IntervalKeyPool import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal +import io.shiftleft.utils.IOUtils import org.slf4j.LoggerFactory -import overflowdb.BatchedUpdate import java.nio.charset.StandardCharsets +import java.nio.file.Path +import scala.collection.mutable -class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], disableFileContent: Boolean)(implicit +class AstCreator(relativeFileName: String, fileName: String, phpAst: PhpFile, disableFileContent: Boolean)(implicit withSchemaValidation: ValidationMode -) extends AstCreatorBase(filename) +) extends AstCreatorBase(relativeFileName) with AstNodeBuilder[PhpNode, AstCreator] { private val logger = LoggerFactory.getLogger(AstCreator.getClass) private val scope = new Scope()(() => nextClosureName()) private val tmpKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) private val globalNamespace = globalNamespaceBlock() + private var fileContent = Option.empty[String] private def getNewTmpName(prefix: String = "tmp"): String = s"$prefix${tmpKeyPool.next.toString}" - override def createAst(): BatchedUpdate.DiffGraphBuilder = { + override def createAst(): DiffGraphBuilder = { + if (!disableFileContent) { + fileContent = Option(IOUtils.readEntireFile(Path.of(fileName))) + } + val ast = astForPhpFile(phpAst) storeInDiffGraph(ast, diffGraph) diffGraph @@ -51,7 +58,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], file, globalNamespace.name, globalNamespace.fullName, - filename, + relativeFileName, globalNamespace.code, NodeTypes.NAMESPACE_BLOCK, globalNamespace.fullName @@ -75,7 +82,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], } private def astForPhpFile(file: PhpFile): Ast = { - val fileNode = NewFile().name(filename) + val fileNode = NewFile().name(relativeFileName) fileContent.foreach(fileNode.content(_)) scope.pushNewScope(globalNamespace) @@ -135,7 +142,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], case enumCase: PhpEnumCaseStmt => astForEnumCase(enumCase) :: Nil case staticStmt: PhpStaticStmt => astsForStaticStmt(staticStmt) case unhandled => - logger.error(s"Unhandled stmt $unhandled in $filename") + logger.error(s"Unhandled stmt $unhandled in $relativeFileName") ??? } } @@ -231,7 +238,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], } val methodCode = s"${modifierString}function $methodName(${parameters.map(_.rootCodeOrEmpty).mkString(",")})" - val method = methodNode(decl, methodName, methodCode, fullName, Some(signature), filename) + val method = methodNode(decl, methodName, methodCode, fullName, Some(signature), relativeFileName) scope.pushNewScope(method) @@ -477,7 +484,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], private def astForNamespaceStmt(stmt: PhpNamespaceStmt): Ast = { val name = stmt.name.map(_.name).getOrElse(NameConstants.Unknown) - val fullName = s"$filename:$name" + val fullName = s"$relativeFileName:$name" val namespaceBlock = NewNamespaceBlock() .name(name) @@ -574,6 +581,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], private def astForForeachStmt(stmt: PhpForeachStmt): Ast = { val iterIdentifier = getTmpIdentifier(stmt, maybeTypeFullName = None, prefix = "iter_") + // keep this just used to construct the `code` field val assignItemTargetAst = stmt.keyVar match { case Some(key) => astForKeyValPair(key, stmt.valueVar, line(stmt)) case None => astForExpr(stmt.valueVar) @@ -585,7 +593,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], val iteratorAssignAst = simpleAssignAst(Ast(iterIdentifier), iterValue, line(stmt)) // - Assigned item assign - val itemInitAst = getItemAssignAstForForeach(stmt, assignItemTargetAst, iterIdentifier.copy) + val itemInitAst = getItemAssignAstForForeach(stmt, iterIdentifier.copy) // Condition ast val isNullName = PhpOperators.isNull @@ -614,7 +622,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], val itemUpdateAst = itemInitAst.root match { case Some(initRoot: AstNodeNew) => itemInitAst.subTreeCopy(initRoot) case _ => - logger.warn(s"Could not copy foreach init ast in $filename") + logger.warn(s"Could not copy foreach init ast in $relativeFileName") Ast() } @@ -631,34 +639,63 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], .withConditionEdges(foreachNode, conditionAst.root.toList) } - private def getItemAssignAstForForeach( - stmt: PhpForeachStmt, - assignItemTargetAst: Ast, - iteratorIdentifier: NewIdentifier - ): Ast = { - val iteratorIdentifierAst = Ast(iteratorIdentifier) - val currentCallSignature = s"$UnresolvedSignature(0)" - val currentCallCode = s"${iteratorIdentifierAst.rootCodeOrEmpty}->current()" - val currentCallNode = callNode( - stmt, - currentCallCode, - "current", - "Iterator.current", - DispatchTypes.DYNAMIC_DISPATCH, - Some(currentCallSignature), - Some(TypeConstants.Any) + private def getItemAssignAstForForeach(stmt: PhpForeachStmt, iteratorIdentifier: NewIdentifier): Ast = { + // create assignment for value-part + val valueAssign = { + val iteratorIdentifierAst = Ast(iteratorIdentifier) + val currentCallSignature = s"$UnresolvedSignature(0)" + val currentCallCode = s"${iteratorIdentifierAst.rootCodeOrEmpty}->current()" + // `current` function is used to get the current element of given array + // see https://www.php.net/manual/en/function.current.php & https://www.php.net/manual/en/iterator.current.php + val currentCallNode = callNode( + stmt, + currentCallCode, + "current", + "Iterator.current", + DispatchTypes.DYNAMIC_DISPATCH, + Some(currentCallSignature), + Some(TypeConstants.Any) + ) + val currentCallAst = callAst(currentCallNode, base = Option(iteratorIdentifierAst)) + + val valueAst = if (stmt.assignByRef) { + val addressOfCode = s"&${currentCallAst.rootCodeOrEmpty}" + val addressOfCall = newOperatorCallNode(Operators.addressOf, addressOfCode, line = line(stmt)) + callAst(addressOfCall, currentCallAst :: Nil) + } else { + currentCallAst + } + simpleAssignAst(astForExpr(stmt.valueVar), valueAst, line(stmt)) + } + + // try to create assignment for key-part + val keyAssignOption = stmt.keyVar.map(keyVar => + val iteratorIdentifierAst = Ast(iteratorIdentifier.copy) + val keyCallSignature = s"$UnresolvedSignature(0)" + val keyCallCode = s"${iteratorIdentifierAst.rootCodeOrEmpty}->key()" + // `key` function is used to get the key of the current element + // see https://www.php.net/manual/en/function.key.php & https://www.php.net/manual/en/iterator.key.php + val keyCallNode = callNode( + stmt, + keyCallCode, + "key", + "Iterator.key", + DispatchTypes.DYNAMIC_DISPATCH, + Some(keyCallSignature), + Some(TypeConstants.Any) + ) + val keyCallAst = callAst(keyCallNode, base = Option(iteratorIdentifierAst)) + simpleAssignAst(astForExpr(keyVar), keyCallAst, line(stmt)) ) - val currentCallAst = callAst(currentCallNode, base = Option(iteratorIdentifierAst)) - val valueAst = if (stmt.assignByRef) { - val addressOfCode = s"&${currentCallAst.rootCodeOrEmpty}" - val addressOfCall = newOperatorCallNode(Operators.addressOf, addressOfCode, line = line(stmt)) - callAst(addressOfCall, currentCallAst :: Nil) - } else { - currentCallAst + keyAssignOption match { + case Some(keyAssign) => + Ast(blockNode(stmt)) + .withChild(keyAssign) + .withChild(valueAssign) + case None => + valueAssign } - - simpleAssignAst(assignItemTargetAst, valueAst, line(stmt)) } private def simpleAssignAst(target: Ast, source: Ast, lineNo: Option[Int]): Ast = { @@ -716,7 +753,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], Ast(local) :: assignmentAst.toList case other => - logger.warn(s"Unexpected static variable type $other in $filename") + logger.warn(s"Unexpected static variable type $other in $relativeFileName") Nil } } @@ -753,7 +790,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], prependNamespacePrefix(name.name) } - val typeDecl = typeDeclNode(stmt, name.name, fullName, filename, code, inherits = inheritsFrom) + val typeDecl = typeDeclNode(stmt, name.name, fullName, relativeFileName, code, inherits = inheritsFrom) val createDefaultConstructor = stmt.hasConstructor scope.pushNewScope(typeDecl) @@ -772,7 +809,16 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], case inits => val signature = s"${TypeConstants.Void}()" val fullName = composeMethodFullName(StaticInitMethodName, isStatic = true) - val ast = staticInitMethodAst(inits, fullName, Option(signature), TypeConstants.Void, fileName = Some(filename)) + val ast = + staticInitMethodAst(inits, fullName, Option(signature), TypeConstants.Void, fileName = Some(relativeFileName)) + + for { + method <- ast.root.collect { case method: NewMethod => method } + content <- fileContent + } { + method.offset(0) + method.offsetEnd(content.length) + } Option(ast) } @@ -837,7 +883,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], val thisParam = thisParamAstForMethod(originNode) - val method = methodNode(originNode, ConstructorMethodName, fullName, fullName, Some(signature), filename) + val method = methodNode(originNode, ConstructorMethodName, fullName, fullName, Some(signature), relativeFileName) val methodBody = blockAst(blockNode(originNode), scope.getFieldInits) @@ -855,11 +901,15 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], val fieldIdentifier = newFieldIdentifierNode(memberNode.name, memberNode.lineNumber) callAst(fieldAccessNode, List(identifier, fieldIdentifier).map(Ast(_))).withRefEdges(identifier, thisParam.toList) } else { - val identifierCode = memberNode.code.replaceAll("const ", "").replaceAll("case ", "") - val typeFullName = Option(memberNode.typeFullName) - val identifier = newIdentifierNode(memberNode.name, typeFullName.getOrElse("ANY")) - .code(identifierCode) - Ast(identifier).withRefEdge(identifier, memberNode) + val selfIdentifier = { + val name = "self" + val typ = scope.getEnclosingTypeDeclTypeName + newIdentifierNode(name, typ.getOrElse(Defines.Any), typ.toList, memberNode.lineNumber).code(name) + } + val fieldIdentifier = newFieldIdentifierNode(memberNode.name, memberNode.lineNumber) + val code = s"self::${memberNode.code.replaceAll("(static|case|const) ", "")}" + val fieldAccessNode = newOperatorCallNode(Operators.fieldAccess, code, line = memberNode.lineNumber) + callAst(fieldAccessNode, List(selfIdentifier, fieldIdentifier).map(Ast(_))) } val value = astForExpr(valueExpr) @@ -878,7 +928,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], val name = constDecl.name.name val code = s"const $name" val someValue = Option(constDecl.value) - astForConstOrFieldValue(stmt, name, code, someValue, scope.addConstOrStaticInitToScope, isField = false) + astForConstOrStaticOrFieldValue(stmt, name, code, someValue, scope.addConstOrStaticInitToScope, isField = false) .withChildren(modifierAsts) } } @@ -889,21 +939,35 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], val name = stmt.name.name val code = s"case $name" - astForConstOrFieldValue(stmt, name, code, stmt.expr, scope.addConstOrStaticInitToScope, isField = false) + astForConstOrStaticOrFieldValue(stmt, name, code, stmt.expr, scope.addConstOrStaticInitToScope, isField = false) .withChild(finalModifier) } private def astsForPropertyStmt(stmt: PhpPropertyStmt): List[Ast] = { stmt.variables.map { varDecl => - val modifierAsts = stmt.modifiers.map(newModifierNode).map(Ast(_)) + val modifiers = stmt.modifiers + val modifierAsts = modifiers.map(newModifierNode).map(Ast(_)) val name = varDecl.name.name - astForConstOrFieldValue(stmt, name, s"$$$name", varDecl.defaultValue, scope.addFieldInitToScope, isField = true) - .withChildren(modifierAsts) + val ast = if (modifiers.contains(ModifierTypes.STATIC)) { + // A static member belongs to a class, not an instance + val memberCode = s"static $$$name" + astForConstOrStaticOrFieldValue( + stmt, + name, + memberCode, + varDecl.defaultValue, + scope.addConstOrStaticInitToScope, + false + ) + } else + astForConstOrStaticOrFieldValue(stmt, name, s"$$$name", varDecl.defaultValue, scope.addFieldInitToScope, true) + + ast.withChildren(modifierAsts) } } - private def astForConstOrFieldValue( + private def astForConstOrStaticOrFieldValue( originNode: PhpNode, name: String, code: String, @@ -1087,7 +1151,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], .sortBy(_.argumentIndex) if (args.size != 2) { - val position = s"${line(assignment).getOrElse("")}:$filename" + val position = s"${line(assignment).getOrElse("")}:$relativeFileName" logger.warn(s"Expected 2 call args for emptyArrayDimAssign. Not resetting code: $position") } else { val codeOverride = s"${args.head.code}[] = ${args.last.code}" @@ -1097,12 +1161,103 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], arrayPushAst } + /** Lower the array/list unpack. For example `[$a, $b] = $arr;` will be lowered to `$a = $arr[0]; $b = $arr[1];` + */ + private def astForArrayUnpack(assignment: PhpAssignment, target: PhpArrayExpr | PhpListExpr): Ast = { + val loweredAssignNodes = mutable.ListBuffer.empty[Ast] + + // create a Identifier ast for given name + def createIdentifier(name: String): Ast = Ast(identifierNode(assignment, name, s"$$$name", TypeConstants.Any)) + + def createIndexAccessChain( + targetAst: Ast, + sourceAst: Ast, + idxTracker: ArrayIndexTracker, + item: PhpArrayItem + ): Ast = { + // copy from `assignForArrayItem` to handle the case where key exists, such as `list("id" => $a, "name" => $b) = $arr;` + val dimension = item.key match { + case Some(key: PhpSimpleScalar) => dimensionFromSimpleScalar(key, idxTracker) + case Some(key) => key + case None => PhpInt(idxTracker.next, item.attributes) + } + val dimensionAst = astForExpr(dimension) + val indexAccessCode = s"${sourceAst.rootCodeOrEmpty}[${dimensionAst.rootCodeOrEmpty}]" + // .indexAccess(sourceAst, index) + val indexAccessNode = callAst( + newOperatorCallNode(Operators.indexAccess, indexAccessCode, line = line(item)), + sourceAst :: dimensionAst :: Nil + ) + val assignCode = s"${targetAst.rootCodeOrEmpty} = $indexAccessCode" + val assignNode = newOperatorCallNode(Operators.assignment, assignCode, line = line(item)) + // targetAst = .indexAccess(sourceAst, index) + callAst(assignNode, targetAst :: indexAccessNode :: Nil) + } + + // Take `[[$a, $b], $c] = $arr;` as an example + def handleUnpackLowering( + target: PhpArrayExpr | PhpListExpr, + itemsOf: PhpArrayExpr | PhpListExpr => List[Option[PhpArrayItem]], + sourceAst: Ast + ): Unit = { + val idxTracker = new ArrayIndexTracker + + // create an alias identifier of $arr + val sourceAliasName = getNewTmpName() + val sourceAliasIdentifier = createIdentifier(sourceAliasName) + val assignCode = s"${sourceAliasIdentifier.rootCodeOrEmpty} = ${sourceAst.rootCodeOrEmpty}" + val assignNode = newOperatorCallNode(Operators.assignment, assignCode, line = line(assignment)) + loweredAssignNodes += callAst(assignNode, sourceAliasIdentifier :: sourceAst :: Nil) + + itemsOf(target).foreach { + case Some(item) => + item.value match { + case nested: (PhpArrayExpr | PhpListExpr) => // item is [$a, $b] + // create tmp variable for [$a, $b] to receive the result of .indexAccess($arr, 0) + val tmpIdentifierName = getNewTmpName() + // tmpVar = .indexAccess($arr, 0) + val targetAssignNode = + createIndexAccessChain( + createIdentifier(tmpIdentifierName), + createIdentifier(sourceAliasName), + idxTracker, + item + ) + loweredAssignNodes += targetAssignNode + handleUnpackLowering(nested, itemsOf, createIdentifier(tmpIdentifierName)) + case phpVar: PhpVariable => // item is $c + val identifier = astForExpr(phpVar) + // $c = .indexAccess($arr, 1) + val targetAssignNode = + createIndexAccessChain(identifier, createIdentifier(sourceAliasName), idxTracker, item) + loweredAssignNodes += targetAssignNode + case _ => + // unknown case + idxTracker.next + } + case None => + idxTracker.next + } + } + + val sourceAst = astForExpr(assignment.source) + val itemsOf = (exp: PhpArrayExpr | PhpListExpr) => + exp match { + case x: PhpArrayExpr => x.items + case x: PhpListExpr => x.items + } + handleUnpackLowering(target, itemsOf, sourceAst) + Ast(blockNode(assignment)) + .withChildren(loweredAssignNodes.toList) + } + private def astForAssignment(assignment: PhpAssignment): Ast = { assignment.target match { case arrayDimFetch: PhpArrayDimFetchExpr if arrayDimFetch.dimension.isEmpty => // Rewrite `$xs[] = ` as `array_push($xs, )` to simplify finding dataflows. astForEmptyArrayDimAssign(assignment, arrayDimFetch) - + case arrayExpr: (PhpArrayExpr | PhpListExpr) => + astForArrayUnpack(assignment, arrayExpr) case _ => val operatorName = assignment.assignOp @@ -1293,10 +1448,30 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], private def astForArrayExpr(expr: PhpArrayExpr): Ast = { val idxTracker = new ArrayIndexTracker - val tmpIdentifier = getTmpIdentifier(expr, Some(TypeConstants.Array)) + val tmpName = getNewTmpName() + + def newTmpIdentifier: Ast = Ast(identifierNode(expr, tmpName, s"$$$tmpName", TypeConstants.Array)) + + val tmpIdentifierAssignNode = { + // use array() function to create an empty array. see https://www.php.net/manual/zh/function.array.php + val initArrayNode = callNode( + expr, + "array()", + "array", + "array", + DispatchTypes.STATIC_DISPATCH, + Some("array()"), + Some(TypeConstants.Array) + ) + val initArrayCallAst = callAst(initArrayNode) + + val assignCode = s"$$$tmpName = ${initArrayCallAst.rootCodeOrEmpty}" + val assignNode = newOperatorCallNode(Operators.assignment, assignCode, line = line(expr)) + callAst(assignNode, newTmpIdentifier :: initArrayCallAst :: Nil) + } val itemAssignments = expr.items.flatMap { - case Some(item) => Option(assignForArrayItem(item, tmpIdentifier.name, idxTracker)) + case Some(item) => Option(assignForArrayItem(item, tmpName, idxTracker)) case None => idxTracker.next // Skip an index None @@ -1304,8 +1479,9 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], val arrayBlock = blockNode(expr) Ast(arrayBlock) + .withChild(tmpIdentifierAssignNode) .withChildren(itemAssignments) - .withChild(Ast(tmpIdentifier)) + .withChild(newTmpIdentifier) } private def astForListExpr(expr: PhpListExpr): Ast = { @@ -1419,14 +1595,14 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], Some(localNode(closureExpr, name, s"$byRefPrefix$$$name", typeFullName)) case other => - logger.warn(s"Found incorrect closure use variable '$other' in $filename") + logger.warn(s"Found incorrect closure use variable '$other' in $relativeFileName") None } } // Add closure bindings to diffgraph localsForUses.foreach { local => - val closureBindingId = s"$filename:$methodName:${local.name}" + val closureBindingId = s"$relativeFileName:$methodName:${local.name}" local.closureBindingId(closureBindingId) scope.addToScope(local.name, local) @@ -1607,7 +1783,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], callAst(accessNode, variableAst :: dimensionAst :: Nil) case None => - val errorPosition = s"$variableCode:${line(expr).getOrElse("")}:$filename" + val errorPosition = s"$variableCode:${line(expr).getOrElse("")}:$relativeFileName" logger.error(s"ArrayDimFetchExpr without dimensions should be handled in assignment: $errorPosition") Ast() } @@ -1681,7 +1857,7 @@ class AstCreator(filename: String, phpAst: PhpFile, fileContent: Option[String], .getOrElse(nameExpr.name) case expr => - logger.warn(s"Unexpected expression as class name in ::class expression: $filename") + logger.warn(s"Unexpected expression as class name in ::class expression: $relativeFileName") NameConstants.Unknown } diff --git a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/parser/ClassParser.scala b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/parser/ClassParser.scala index 0ac7a485d9b0..7110da272d71 100644 --- a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/parser/ClassParser.scala +++ b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/parser/ClassParser.scala @@ -3,7 +3,6 @@ import better.files.File import io.joern.x2cpg.utils.ExternalCommand import org.slf4j.LoggerFactory -import scala.collection.immutable.LazyList.from import scala.io.Source import scala.util.{Failure, Success, Try, Using} import upickle.default.* @@ -26,11 +25,12 @@ class ClassParser(targetDir: File) { f } - private lazy val phpClassParseCommand: String = s"php ${classParserScript.pathAsString} ${targetDir.pathAsString}" + private lazy val phpClassParseCommand: Seq[String] = + Seq("php", classParserScript.pathAsString, targetDir.pathAsString) def parse(): Try[List[ClassParserClass]] = Try { val inputDirectory = targetDir.parent.canonicalPath - ExternalCommand.run(phpClassParseCommand, inputDirectory).map(_.reverse) match { + ExternalCommand.run(phpClassParseCommand, inputDirectory).toTry.map(_.reverse) match { case Success(output) => read[List[ClassParserClass]](output.mkString("\n")) case Failure(exception) => diff --git a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/parser/PhpParser.scala b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/parser/PhpParser.scala index 45017300cafa..9a1a72bb28a1 100644 --- a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/parser/PhpParser.scala +++ b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/parser/PhpParser.scala @@ -6,7 +6,10 @@ import io.joern.php2cpg.parser.Domain.PhpFile import io.joern.x2cpg.utils.ExternalCommand import org.slf4j.LoggerFactory +import java.nio.file.Path import java.nio.file.Paths +import java.util.regex.Pattern +import scala.collection.mutable import scala.io.Source import scala.util.{Failure, Success, Try} @@ -14,56 +17,144 @@ class PhpParser private (phpParserPath: String, phpIniPath: String, disableFileC private val logger = LoggerFactory.getLogger(this.getClass) - private def phpParseCommand(filename: String): String = { - val phpParserCommands = "--with-recovery --resolve-names --json-dump" - s"php --php-ini $phpIniPath $phpParserPath $phpParserCommands $filename" + private def phpParseCommand(filenames: collection.Seq[String]): Seq[String] = { + val phpParserCommands = Seq("--with-recovery", "--resolve-names", "--json-dump") + Seq("php", "--php-ini", phpIniPath, phpParserPath) ++ phpParserCommands ++ filenames } - def parseFile(inputPath: String): Option[(PhpFile, Option[String])] = { - val inputFile = File(inputPath) - val inputFilePath = inputFile.canonicalPath - val inputDirectory = inputFile.parent.canonicalPath - val command = phpParseCommand(inputFilePath) - ExternalCommand.run(command, inputDirectory) match { - case Success(output) => - val content = Option.unless(disableFileContent)(inputFile.contentAsString) - processParserOutput(output, inputFilePath).map((_, content)) - case Failure(exception) => - logger.error(s"Failure running php-parser with $command", exception.getMessage) - None + def parseFiles(inputPaths: collection.Seq[String]): collection.Seq[(String, Option[PhpFile], String)] = { + // We need to keep a map between the input path and its canonical representation in + // order to map back the canonical file name we get from the php parser. + // Otherwise later on file name/path processing might get confused because the returned + // file paths are in no relation to the input paths. + val canonicalToInputPath = mutable.HashMap.empty[String, String] + + inputPaths.foreach { inputPath => + val canonicalPath = Path.of(inputPath).toFile.getCanonicalPath + canonicalToInputPath.put(canonicalPath, inputPath) } - } - private def processParserOutput(output: Seq[String], filename: String): Option[PhpFile] = { - val maybeJson = linesToJsonValue(output, filename) - maybeJson.flatMap(jsonValueToPhpFile(_, filename)) + val command = phpParseCommand(inputPaths) + + val result = ExternalCommand.run(command, ".", mergeStdErrInStdOut = true) + result match { + case ExternalCommand.ExternalCommandResult(0, stdOut, _) => + val asJson = linesToJsonValues(stdOut) + val asPhpFile = asJson.map { case (filename, jsonObjectOption, infoLines) => + (filename, jsonToPhpFile(jsonObjectOption, filename), infoLines) + } + val withRemappedFileName = asPhpFile.map { case (filename, phpFileOption, infoLines) => + (canonicalToInputPath.apply(filename), phpFileOption, infoLines) + } + withRemappedFileName + case ExternalCommand.ExternalCommandResult(exitCode, _, _) => + logger.error(s"Failure running php-parser with ${command.mkString(" ")}, exit code $exitCode") + Nil + } } - private def linesToJsonValue(lines: Seq[String], filename: String): Option[ujson.Value] = { - if (lines.exists(_.startsWith("["))) { - val jsonString = lines.dropWhile(_.charAt(0) != '[').mkString - Try(Option(ujson.read(jsonString))) match { - case Success(Some(value)) => Some(value) - case Success(None) => - logger.error(s"Parsing json string for $filename resulted in null return value") - None - case Failure(exception) => - logger.error(s"Parsing json string for $filename failed with exception", exception) + private def jsonToPhpFile(jsonObject: Option[ujson.Value], filename: String): Option[PhpFile] = { + val phpFile = jsonObject.flatMap { jsonObject => + Try(Domain.fromJson(jsonObject)) match { + case Success(phpFile) => + Some(phpFile) + case Failure(e) => + logger.error(s"Failed to generate intermediate AST for $filename", e) None } - } else { - logger.warn(s"No JSON output for $filename") - None } + phpFile } - private def jsonValueToPhpFile(json: ujson.Value, filename: String): Option[PhpFile] = { - Try(Domain.fromJson(json)) match { - case Success(phpFile) => Some(phpFile) - case Failure(e) => - logger.error(s"Failed to generate intermediate AST for $filename", e) - None + enum PARSE_MODE { + case PARSE_INFO, PARSE_JSON, SKIP_TRAILER, SKIP_WARNING + } + + private def getJsonResult( + filename: String, + jsonLines: Array[String], + infoLines: Array[String] + ): collection.Seq[(String, Option[ujson.Value], String)] = { + val result = mutable.ArrayBuffer.empty[(String, Option[ujson.Value], String)] + + val jsonString = jsonLines.mkString + + Try(Option(ujson.read(jsonString))) match { + case Success(option) => + result.append((filename, option, infoLines.mkString)) + if (option.isEmpty) { + logger.error(s"Parsing json string for $filename resulted in null return value") + } + case Failure(exception) => + result.append((filename, None, infoLines.mkString)) + logger.error(s"Parsing json string for $filename failed with exception", exception) + } + + result + } + + private def logWarning(lines: collection.Seq[String]): Unit = { + if (lines.exists(_.nonEmpty)) { + logger.warn(s"Found warning in PHP-Parser JSON output:\n${lines.mkString("\n")}") + } + } + + private def linesToJsonValues( + lines: collection.Seq[String] + ): collection.Seq[(String, Option[ujson.Value], String)] = { + val filePrefix = "====> File " + val filenameRegex = Pattern.compile(s"$filePrefix(.*):") + val result = mutable.ArrayBuffer.empty[(String, Option[ujson.Value], String)] + + var filename = "" + val infoLines = mutable.ArrayBuffer.empty[String] + val jsonLines = mutable.ArrayBuffer.empty[String] + val warningLines = mutable.ArrayBuffer.empty[String] + + var mode = PARSE_MODE.SKIP_TRAILER + val linesIt = lines.iterator + while (linesIt.hasNext) { + val line = linesIt.next + mode match { + case PARSE_MODE.PARSE_INFO => + if (line != "==> JSON dump:") { + infoLines.append(line) + } else { + mode = PARSE_MODE.SKIP_WARNING + } + case PARSE_MODE.SKIP_WARNING => + if (line == "[]") { + logWarning(warningLines) + jsonLines.append(line) + result.appendAll(getJsonResult(filename, jsonLines.toArray, infoLines.toArray)) + mode = PARSE_MODE.SKIP_TRAILER + } else if (line.startsWith("[")) { + logWarning(warningLines) + jsonLines.append(line) + mode = PARSE_MODE.PARSE_JSON + } else { + warningLines.append(line) + } + case PARSE_MODE.PARSE_JSON => + jsonLines.append(line) + if (line.startsWith("]") || line == "[]") { + result.appendAll(getJsonResult(filename, jsonLines.toArray, infoLines.toArray)) + mode = PARSE_MODE.SKIP_TRAILER + } + case _ => + } + + if (line.startsWith(filePrefix)) { + val matcher = filenameRegex.matcher(line) + if (matcher.find()) { + filename = matcher.group(1) + infoLines.clear() + jsonLines.clear() + mode = PARSE_MODE.PARSE_INFO + } + } } + result } } @@ -79,9 +170,11 @@ object PhpParser { } private def defaultPhpParserBin: String = { - val dir = Paths.get(this.getClass.getProtectionDomain.getCodeSource.getLocation.toURI).toAbsolutePath.toString - val fixedDir = new java.io.File(dir.substring(0, dir.indexOf("php2cpg"))).toString - Paths.get(fixedDir, "php2cpg", "bin", "php-parser", "php-parser.php").toAbsolutePath.toString + val packagePath = Paths.get(this.getClass.getProtectionDomain.getCodeSource.getLocation.toURI) + ExternalCommand + .executableDir(packagePath) + .resolve("php-parser/php-parser.php") + .toString } private def configOverrideOrDefaultPath( diff --git a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AnyTypePass.scala b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AnyTypePass.scala index 8f9c618a3333..c65478137eb0 100644 --- a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AnyTypePass.scala +++ b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AnyTypePass.scala @@ -6,14 +6,14 @@ import io.shiftleft.codepropertygraph.generated.PropertyNames import io.shiftleft.codepropertygraph.generated.nodes.AstNode import io.shiftleft.codepropertygraph.generated.nodes.Call.PropertyDefaults import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* // TODO This is a hack for a customer issue. Either extend this to handle type full names properly, // or do it elsewhere. class AnyTypePass(cpg: Cpg) extends ForkJoinParallelCpgPass[AstNode](cpg) { override def generateParts(): Array[AstNode] = { - cpg.has(PropertyNames.TYPE_FULL_NAME, PropertyDefaults.TypeFullName).collectAll[AstNode].toArray + cpg.graph.nodesWithProperty(PropertyNames.TYPE_FULL_NAME, PropertyDefaults.TypeFullName).collectAll[AstNode].toArray } override def runOnPart(diffGraph: DiffGraphBuilder, node: AstNode): Unit = { diff --git a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AstCreationPass.scala b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AstCreationPass.scala index 7d8e3f0c15f5..01d80874a7cc 100644 --- a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AstCreationPass.scala +++ b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AstCreationPass.scala @@ -12,36 +12,53 @@ import org.slf4j.LoggerFactory import scala.jdk.CollectionConverters.* class AstCreationPass(config: Config, cpg: Cpg, parser: PhpParser)(implicit withSchemaValidation: ValidationMode) - extends ForkJoinParallelCpgPass[String](cpg) { + extends ForkJoinParallelCpgPass[Array[String]](cpg) { private val logger = LoggerFactory.getLogger(this.getClass) val PhpSourceFileExtensions: Set[String] = Set(".php") - override def generateParts(): Array[String] = SourceFiles - .determine( - config.inputPath, - PhpSourceFileExtensions, - ignoredFilesRegex = Option(config.ignoredFilesRegex), - ignoredFilesPath = Option(config.ignoredFiles) - ) - .toArray - - override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = { - val relativeFilename = if (filename == config.inputPath) { - File(filename).name - } else { - File(config.inputPath).relativize(File(filename)).toString - } - parser.parseFile(filename) match { - case Some((parseResult, fileContent)) => - diffGraph.absorb( - new AstCreator(relativeFilename, parseResult, fileContent, config.disableFileContent)(config.schemaValidation) - .createAst() - ) - - case None => - logger.warn(s"Could not parse file $filename. Results will be missing!") + override def generateParts(): Array[Array[String]] = { + val sourceFiles = SourceFiles + .determine( + config.inputPath, + PhpSourceFileExtensions, + ignoredFilesRegex = Option(config.ignoredFilesRegex), + ignoredFilesPath = Option(config.ignoredFiles) + ) + .toArray + + // We need to feed the php parser big groups of file in order + // to speed up the parsing. Apparently it is some sort of slow + // startup phase which makes single file processing prohibitively + // slow. + // On the other hand we need to be careful to not choose too big + // chunks because: + // 1. The argument length to the php executable has system + // dependent limits + // 2. We want to make use of multiple CPU cores for the rest + // of the CPG creation. + // + val parts = sourceFiles.grouped(20).toArray + parts + } + + override def runOnPart(diffGraph: DiffGraphBuilder, filenames: Array[String]): Unit = { + parser.parseFiles(filenames).foreach { case (filename, parseResult, infoLines) => + parseResult match { + case Some(parseResult) => + val relativeFilename = if (filename == config.inputPath) { + File(filename).name + } else { + File(config.inputPath).relativize(File(filename)).toString + } + diffGraph.absorb( + new AstCreator(relativeFilename, filename, parseResult, config.disableFileContent)(config.schemaValidation) + .createAst() + ) + case None => + logger.warn(s"Could not parse file $filename. Results will be missing!") + } } } } diff --git a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AstParentInfoPass.scala b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AstParentInfoPass.scala index 8594efb2a949..701b7523d033 100644 --- a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AstParentInfoPass.scala +++ b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/AstParentInfoPass.scala @@ -1,10 +1,9 @@ package io.joern.php2cpg.passes -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.PropertyNames +import io.shiftleft.codepropertygraph.generated.{Cpg, Properties, PropertyNames} import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, NamespaceBlock, Method, TypeDecl} import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class AstParentInfoPass(cpg: Cpg) extends ForkJoinParallelCpgPass[AstNode](cpg) { @@ -15,7 +14,7 @@ class AstParentInfoPass(cpg: Cpg) extends ForkJoinParallelCpgPass[AstNode](cpg) override def runOnPart(diffGraph: DiffGraphBuilder, node: AstNode): Unit = { findParent(node).foreach { parentNode => val astParentType = parentNode.label - val astParentFullName = parentNode.property(PropertyNames.FULL_NAME) + val astParentFullName = parentNode.property(Properties.FullName) diffGraph.setNodeProperty(node, PropertyNames.AST_PARENT_TYPE, astParentType) diffGraph.setNodeProperty(node, PropertyNames.AST_PARENT_FULL_NAME, astParentFullName) diff --git a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/ClosureRefPass.scala b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/ClosureRefPass.scala index ac790b9fb438..c5c2c289a54c 100644 --- a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/ClosureRefPass.scala +++ b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/ClosureRefPass.scala @@ -3,7 +3,7 @@ package io.joern.php2cpg.passes import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{ClosureBinding, Method, MethodRef} import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.slf4j.LoggerFactory import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.AstNode @@ -22,7 +22,7 @@ class ClosureRefPass(cpg: Cpg) extends ForkJoinParallelCpgPass[ClosureBinding](c * that is the scope in which the closure would have originally been created. */ override def runOnPart(diffGraph: DiffGraphBuilder, closureBinding: ClosureBinding): Unit = { - closureBinding.captureIn.collectAll[MethodRef].toList match { + closureBinding._methodRefViaCaptureIn.toList match { case Nil => logger.error(s"No MethodRef corresponding to closureBinding ${closureBinding.closureBindingId}") diff --git a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/DependencyPass.scala b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/DependencyPass.scala index 003efcc8d33c..8b56de0fa15e 100644 --- a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/DependencyPass.scala +++ b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/DependencyPass.scala @@ -6,7 +6,6 @@ import io.shiftleft.codepropertygraph.generated.nodes.{NewDependency, NewTag} import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes} import io.shiftleft.passes.ForkJoinParallelCpgPass import org.slf4j.LoggerFactory -import overflowdb.BatchedUpdate import upickle.default.* import scala.annotation.targetName diff --git a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/LocalCreationPass.scala b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/LocalCreationPass.scala index 89ed00ffe137..2e113b80ae0a 100644 --- a/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/LocalCreationPass.scala +++ b/joern-cli/frontends/php2cpg/src/main/scala/io/joern/php2cpg/passes/LocalCreationPass.scala @@ -13,7 +13,7 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ NewNode, TypeDecl } -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.joern.php2cpg.astcreation.AstCreator import io.joern.php2cpg.parser.Domain import io.joern.php2cpg.parser.Domain.PhpOperators @@ -96,7 +96,7 @@ abstract class LocalCreationPass[ScopeType <: AstNode](cpg: Cpg) ): Unit = { val identifierMap = getIdentifiersInScope(bodyNode) - .filter(_.refOut.isEmpty) + .filter(_._refOut.isEmpty) .filterNot(excludeIdentifierFn) .groupBy(_.name) diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/config/ConfigTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/config/ConfigTests.scala index 12ff84b48d26..4174ab10d63f 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/config/ConfigTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/config/ConfigTests.scala @@ -18,7 +18,9 @@ class ConfigTests extends AnyWordSpec with Matchers with Inside { "--output", "OUTPUT", "--exclude", - "1EXCLUDE_FILE,2EXCLUDE_FILE", + "1EXCLUDE_FILE", + "--exclude", + "2EXCLUDE_FILE", "--exclude-regex", "EXCLUDE_REGEX", // Frontend-specific args diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/dataflow/IntraMethodDataflowTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/dataflow/IntraMethodDataflowTests.scala index 737898f2bc05..7d680770be71 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/dataflow/IntraMethodDataflowTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/dataflow/IntraMethodDataflowTests.scala @@ -1,8 +1,8 @@ package io.joern.php2cpg.dataflow import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* class IntraMethodDataflowTests extends PhpCode2CpgFixture(runOssDataflow = true) { "flows from parameters to corresponding identifiers should be found" in { @@ -29,4 +29,50 @@ class IntraMethodDataflowTests extends PhpCode2CpgFixture(runOssDataflow = true) flows.size shouldBe 1 } + + "flow from single layer array unpacking should be found" in { + val cpg = code(""" $value) { + | echo $key; + | echo $value; + |} + |""".stripMargin) + val source = cpg.identifier("arr") + val sink = cpg.call("echo").argument(1) + val flows = sink.reachableByFlows(source) + flows.size shouldBe 2 + } } diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/io/Php2CpgHTTPServerTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/io/Php2CpgHTTPServerTests.scala new file mode 100644 index 000000000000..ddadfc9d0622 --- /dev/null +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/io/Php2CpgHTTPServerTests.scala @@ -0,0 +1,79 @@ +package io.joern.php2cpg.io + +import better.files.File +import io.joern.x2cpg.utils.server.FrontendHTTPClient +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader +import io.shiftleft.semanticcpg.language.* +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.collection.parallel.CollectionConverters.RangeIsParallelizable +import scala.util.Failure +import scala.util.Success + +class Php2CpgHTTPServerTests extends AnyWordSpec with Matchers with BeforeAndAfterAll { + + private var port: Int = -1 + + private def newProjectUnderTest(index: Option[Int] = None): File = { + val dir = File.newTemporaryDirectory("php2cpgTestsHttpTest") + val file = dir / "main.php" + file.createIfNotExists(createParents = true) + val indexStr = index.map(_.toString).getOrElse(""""Hello, World!"""") + file.writeText(s""" fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.call.code.l shouldBe List("""print("Hello, World!")""") + } + } + + "build CPGs correctly (multi-threaded test)" in { + (0 until 10).par.foreach { index => + val cpgOutFile = File.newTemporaryFile("php2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest(Some(index)) + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.call.code.l shouldBe List(s"print($index)") + } + } + } + } + +} diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/passes/CfgCreationPassTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/passes/CfgCreationPassTests.scala index a4d244a01c48..b602b6c82784 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/passes/CfgCreationPassTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/passes/CfgCreationPassTests.scala @@ -29,7 +29,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } |} |""".stripMargin) - succOf("break(1)") shouldBe expected(("$i", AlwaysEdge)) + succOf("break(1)") should contain theSameElementsAs expected(("$i", AlwaysEdge)) } "be correct for break with level 2" in { @@ -40,7 +40,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } |} |""".stripMargin) - succOf("break(2)") shouldBe expected(("RET", AlwaysEdge)) + succOf("break(2)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for continue with level 1" in { @@ -51,7 +51,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } |} |""".stripMargin) - succOf("continue(1)") shouldBe expected(("$j", AlwaysEdge)) + succOf("continue(1)") should contain theSameElementsAs expected(("$j", AlwaysEdge)) } "be correct for continue with level 2" in { @@ -62,7 +62,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } |} |""".stripMargin) - succOf("continue(2)") shouldBe expected(("$i", AlwaysEdge)) + succOf("continue(2)") should contain theSameElementsAs expected(("$i", AlwaysEdge)) } } @@ -75,7 +75,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } while ($j < 1); |} while ($i < 1) |""".stripMargin) - succOf("break(1)") shouldBe expected(("$i", AlwaysEdge)) + succOf("break(1)") should contain theSameElementsAs expected(("$i", AlwaysEdge)) } "be correct for break with level 2" in { @@ -86,7 +86,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } while ($j < 1); |} while ($i < 1) |""".stripMargin) - succOf("break(2)") shouldBe expected(("RET", AlwaysEdge)) + succOf("break(2)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for continue with level 1" in { @@ -97,7 +97,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } while ($j < 1); |} while ($i < 1) |""".stripMargin) - succOf("continue(1)") shouldBe expected(("$j", AlwaysEdge)) + succOf("continue(1)") should contain theSameElementsAs expected(("$j", AlwaysEdge)) } "be correct for continue with level 2" in { @@ -108,7 +108,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } while ($j < 1); |} while ($i < 1) |""".stripMargin) - succOf("continue(2)") shouldBe expected(("$i", AlwaysEdge)) + succOf("continue(2)") should contain theSameElementsAs expected(("$i", AlwaysEdge)) } } @@ -121,7 +121,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } |} |""".stripMargin) - succOf("break(1)") shouldBe expected(("$i", AlwaysEdge)) + succOf("break(1)") should contain theSameElementsAs expected(("$i", AlwaysEdge)) } "be correct for break with level 2" in { @@ -132,7 +132,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } |} |""".stripMargin) - succOf("break(2)") shouldBe expected(("RET", AlwaysEdge)) + succOf("break(2)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } "be correct for continue with level 1" in { @@ -143,7 +143,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } |} |""".stripMargin) - succOf("continue(1)") shouldBe expected(("$j", AlwaysEdge)) + succOf("continue(1)") should contain theSameElementsAs expected(("$j", AlwaysEdge)) } "be correct for continue with level 2" in { @@ -154,7 +154,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | } |} |""".stripMargin) - succOf("continue(2)") shouldBe expected(("$i", AlwaysEdge)) + succOf("continue(2)") should contain theSameElementsAs expected(("$i", AlwaysEdge)) } } @@ -170,7 +170,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | $k; |} |""".stripMargin) - succOf("break(1)") shouldBe expected(("$k", AlwaysEdge)) + succOf("break(1)") should contain theSameElementsAs expected(("$k", AlwaysEdge)) } "be correct for break with level 2" in { @@ -184,7 +184,7 @@ class CfgCreationPassTests extends CfgTestFixture(() => new PhpCfgTestCpg) { | $k; |} |""".stripMargin) - succOf("break(2)") shouldBe expected(("RET", AlwaysEdge)) + succOf("break(2)") should contain theSameElementsAs expected(("RET", AlwaysEdge)) } } } diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/passes/PhpTypeRecoveryPassTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/passes/PhpTypeRecoveryPassTests.scala index d5d19ab7c5bd..0384627d1110 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/passes/PhpTypeRecoveryPassTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/passes/PhpTypeRecoveryPassTests.scala @@ -525,7 +525,7 @@ class PhpTypeRecoveryPassTests extends PhpCode2CpgFixture() { "propagate this QueryBuilder type to the identifier assigned to the inherited call for the wrapped `createQueryBuilder`" in { cpg.method .nameExact("findSomething") - ._containsOut + .containsOut .collectAll[Identifier] .nameExact("queryBuilder") .typeFullName diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ArrayTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ArrayTests.scala index 7e671a914683..bf213794b2c6 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ArrayTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ArrayTests.scala @@ -3,7 +3,7 @@ package io.joern.php2cpg.querying import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Identifier, Literal, Local} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ArrayTests extends PhpCode2CpgFixture { "array accesses with variable keys should be represented as index accesses" in { @@ -77,16 +77,20 @@ class ArrayTests extends PhpCode2CpgFixture { tmpLocal.name shouldBe "tmp0" tmpLocal.code shouldBe "$tmp0" - inside(arrayBlock.astChildren.l) { case List(aAssign: Call, bAssign: Call, tmpIdent: Identifier) => - aAssign.code shouldBe "$tmp0[\"A\"] = 1" - aAssign.lineNumber shouldBe Some(3) + inside(arrayBlock.astChildren.l) { + case List(initAssign: Call, aAssign: Call, bAssign: Call, tmpIdent: Identifier) => + initAssign.code shouldBe "$tmp0 = array()" + initAssign.lineNumber shouldBe Some(2) - bAssign.code shouldBe "$tmp0[\"B\"] = 2" - bAssign.lineNumber shouldBe Some(4) + aAssign.code shouldBe "$tmp0[\"A\"] = 1" + aAssign.lineNumber shouldBe Some(3) - tmpIdent.name shouldBe "tmp0" - tmpIdent.code shouldBe "$tmp0" - tmpIdent._localViaRefOut should contain(tmpLocal) + bAssign.code shouldBe "$tmp0[\"B\"] = 2" + bAssign.lineNumber shouldBe Some(4) + + tmpIdent.name shouldBe "tmp0" + tmpIdent.code shouldBe "$tmp0" + tmpIdent._localViaRefOut should contain(tmpLocal) } } } @@ -103,16 +107,20 @@ class ArrayTests extends PhpCode2CpgFixture { tmpLocal.name shouldBe "tmp0" tmpLocal.code shouldBe "$tmp0" - inside(arrayBlock.astChildren.l) { case List(aAssign: Call, bAssign: Call, tmpIdent: Identifier) => - aAssign.code shouldBe "$tmp0[0] = \"A\"" - aAssign.lineNumber shouldBe Some(3) + inside(arrayBlock.astChildren.l) { + case List(initAssign: Call, aAssign: Call, bAssign: Call, tmpIdent: Identifier) => + initAssign.code shouldBe "$tmp0 = array()" + initAssign.lineNumber shouldBe Some(2) + + aAssign.code shouldBe "$tmp0[0] = \"A\"" + aAssign.lineNumber shouldBe Some(3) - bAssign.code shouldBe "$tmp0[1] = \"B\"" - bAssign.lineNumber shouldBe Some(4) + bAssign.code shouldBe "$tmp0[1] = \"B\"" + bAssign.lineNumber shouldBe Some(4) - tmpIdent.name shouldBe "tmp0" - tmpIdent.code shouldBe "$tmp0" - tmpIdent._localViaRefOut should contain(tmpLocal) + tmpIdent.name shouldBe "tmp0" + tmpIdent.code shouldBe "$tmp0" + tmpIdent._localViaRefOut should contain(tmpLocal) } } } @@ -128,7 +136,10 @@ class ArrayTests extends PhpCode2CpgFixture { tmpLocal.name shouldBe "tmp0" tmpLocal.code shouldBe "$tmp0" - inside(arrayBlock.astChildren.l) { case List(assign: Call, tmpIdent: Identifier) => + inside(arrayBlock.astChildren.l) { case List(initAssign: Call, assign: Call, tmpIdent: Identifier) => + initAssign.code shouldBe "$tmp0 = array()" + initAssign.lineNumber shouldBe Some(2) + assign.code shouldBe "$tmp0[2] = \"A\"" inside(assign.argument.collectAll[Call].argument.l) { case List(array: Identifier, index: Literal) => array.name shouldBe "tmp0" @@ -164,6 +175,7 @@ class ArrayTests extends PhpCode2CpgFixture { inside(arrayBlock.astChildren.l) { case List( + initAssign: Call, aAssign: Call, cAssign: Call, fourAssign: Call, @@ -173,6 +185,9 @@ class ArrayTests extends PhpCode2CpgFixture { eightAssign: Call, tmpIdent: Identifier ) => + initAssign.code shouldBe "$tmp0 = array()" + initAssign.lineNumber shouldBe Some(2) + aAssign.code shouldBe "$tmp0[\"A\"] = \"B\"" cAssign.code shouldBe "$tmp0[0] = \"C\"" fourAssign.code shouldBe "$tmp0[4] = \"D\"" diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CallTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CallTests.scala index 9e554bb19725..7297935b2db3 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CallTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CallTests.scala @@ -5,7 +5,7 @@ import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CallTests extends PhpCode2CpgFixture { "variable call arguments with names matching methods should not have a methodref" in { diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CfgTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CfgTests.scala index c630338e7794..49f9c79a2708 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CfgTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CfgTests.scala @@ -3,7 +3,7 @@ package io.joern.php2cpg.querying import io.joern.php2cpg.parser.Domain.PhpOperators import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.{Call, JumpTarget} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CfgTests extends PhpCode2CpgFixture { "the CFG for match constructs" when { diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CommentTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CommentTests.scala index c7b490ff6922..26728291ee45 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CommentTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/CommentTests.scala @@ -1,7 +1,7 @@ package io.joern.php2cpg.querying import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CommentTests extends PhpCode2CpgFixture { diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ControlStructureTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ControlStructureTests.scala index c67fe7bbc2a4..419d1d83d446 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ControlStructureTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ControlStructureTests.scala @@ -14,7 +14,7 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ Literal, Local } -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.nodes.AstNode import scala.util.Try @@ -1261,17 +1261,24 @@ class ControlStructureTests extends PhpCode2CpgFixture { (initAsts, updateAsts, body) } - inside(initAsts.astChildren.l) { case List(_: Call, valInit: Call) => - valInit.name shouldBe Operators.assignment - valInit.code shouldBe "$key => $val = $iter_tmp0->current()" - inside(valInit.argument.l) { case List(valPair: Call, currentCall: Call) => - valPair.name shouldBe PhpOperators.doubleArrow - valPair.code shouldBe "$key => $val" - inside(valPair.argument.l) { case List(keyId: Identifier, valId: Identifier) => - keyId.name shouldBe "key" - valId.name shouldBe "val" + inside(initAsts.assignment.l) { case List(_: Call, keyInit: Call, valInit: Call) => + keyInit.name shouldBe Operators.assignment + keyInit.code shouldBe "$key = $iter_tmp0->key()" + inside(keyInit.argument.l) { case List(target: Identifier, keyCall: Call) => + target.name shouldBe "key" + keyCall.name shouldBe "key" + keyCall.methodFullName shouldBe s"Iterator.key" + keyCall.code shouldBe "$iter_tmp0->key()" + inside(keyCall.argument(0).start.l) { case List(iterRecv: Identifier) => + iterRecv.name shouldBe "iter_tmp0" + iterRecv.argumentIndex shouldBe 0 } + } + valInit.name shouldBe Operators.assignment + valInit.code shouldBe "$val = $iter_tmp0->current()" + inside(valInit.argument.l) { case List(target: Identifier, currentCall: Call) => + target.name shouldBe "val" currentCall.name shouldBe "current" currentCall.methodFullName shouldBe s"Iterator.current" currentCall.code shouldBe "$iter_tmp0->current()" @@ -1282,9 +1289,12 @@ class ControlStructureTests extends PhpCode2CpgFixture { } } - inside(updateAsts.astChildren.l) { case List(_: Call, valAssign: Call) => - valAssign.name shouldBe Operators.assignment - valAssign.code shouldBe "$key => $val = $iter_tmp0->current()" + inside(updateAsts.astChildren.l) { case List(_: Call, updateBlock: Block) => + val tmp = updateBlock.astChildren.l + inside(updateBlock.assignment.l) { case List(keyInit: Call, valInit: Call) => + keyInit.code shouldBe "$key = $iter_tmp0->key()" + valInit.code shouldBe "$val = $iter_tmp0->current()" + } } inside(body.astChildren.l) { case List(echoCall: Call) => diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/FieldAccessTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/FieldAccessTests.scala index c016b8c0fadb..3925ebd2d799 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/FieldAccessTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/FieldAccessTests.scala @@ -3,7 +3,7 @@ package io.joern.php2cpg.querying import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{FieldIdentifier, Identifier} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class FieldAccessTests extends PhpCode2CpgFixture { diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/LocalTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/LocalTests.scala index 322bd1a5a79d..08a730754cc0 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/LocalTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/LocalTests.scala @@ -1,7 +1,7 @@ package io.joern.php2cpg.querying import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class LocalTests extends PhpCode2CpgFixture { @@ -71,7 +71,6 @@ class LocalTests extends PhpCode2CpgFixture { | } |} |""".stripMargin) - println(cpg.local.name.l) inside(cpg.local.l) { case List(xLocal) => xLocal.name shouldBe "x" xLocal.code shouldBe "static $x" diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/MemberTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/MemberTests.scala index 087fb9e578c8..f6ff1585ab70 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/MemberTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/MemberTests.scala @@ -1,21 +1,24 @@ package io.joern.php2cpg.querying +import io.joern.php2cpg.Config import io.joern.php2cpg.parser.Domain import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.{ModifierTypes, Operators} -import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier, Literal} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier, Literal, Local} +import io.shiftleft.semanticcpg.language.* class MemberTests extends PhpCode2CpgFixture { "class constants" should { - val cpg = code(""" @@ -40,11 +43,20 @@ class MemberTests extends PhpCode2CpgFixture { "have a clinit method with the constant initializers" in { inside(cpg.method.nameExact(Defines.StaticInitMethodName).l) { case List(clinitMethod) => - inside(clinitMethod.body.astChildren.l) { case List(aAssign: Call, bAssign: Call, cAssign: Call) => + inside(clinitMethod.body.astChildren.l) { case List(self: Local, aAssign: Call, bAssign: Call, cAssign: Call) => + self.name shouldBe "self" checkConstAssign(aAssign, "A") checkConstAssign(bAssign, "B") checkConstAssign(cAssign, "C") } + clinitMethod.isExternal shouldBe false + clinitMethod.offset shouldBe Some(0) + clinitMethod.offsetEnd shouldBe Some(source.length) + cpg.file + .name("foo.php") + .content + .map(_.substring(clinitMethod.offset.get, clinitMethod.offsetEnd.get)) + .l shouldBe List(source) } } } @@ -213,9 +225,14 @@ class MemberTests extends PhpCode2CpgFixture { assign.name shouldBe Operators.assignment assign.methodFullName shouldBe Operators.assignment - inside(assign.argument.l) { case List(target: Identifier, source: Literal) => - target.name shouldBe expectedValue - target.code shouldBe expectedValue + inside(assign.argument.l) { case List(target: Call, source: Literal) => + inside(target.argument.l) { case List(base: Identifier, field: FieldIdentifier) => + base.name shouldBe "self" + field.code shouldBe expectedValue + } + + target.name shouldBe Operators.fieldAccess + target.code shouldBe s"self::$expectedValue" target.argumentIndex shouldBe 1 source.code shouldBe s"\"$expectedValue\"" diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/NamespaceTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/NamespaceTests.scala index 4b5ff757d74b..b8e062767765 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/NamespaceTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/NamespaceTests.scala @@ -1,7 +1,7 @@ package io.joern.php2cpg.querying import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.nodes.Method class NamespaceTests extends PhpCode2CpgFixture { diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/OperatorTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/OperatorTests.scala index cc0c46830db0..1d5b184888a1 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/OperatorTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/OperatorTests.scala @@ -4,10 +4,10 @@ import io.joern.php2cpg.astcreation.AstCreator.TypeConstants import io.joern.php2cpg.parser.Domain.{PhpDomainTypeConstants, PhpOperators} import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture import io.joern.x2cpg.Defines -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} +import io.joern.x2cpg.utils.IntervalKeyPool +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, NodeTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Identifier, Literal, TypeRef} -import io.shiftleft.passes.IntervalKeyPool -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal class OperatorTests extends PhpCode2CpgFixture { @@ -497,29 +497,6 @@ class OperatorTests extends PhpCode2CpgFixture { } } - "temporary list implementation should work" in { - // TODO This is a simple placeholder implementation that represents most of the useful information - // in the AST, while being pretty much unusable for dataflow. A better implementation needs to follow. - val cpg = code(""" - listCall.methodFullName shouldBe PhpOperators.listFunc - listCall.code shouldBe "list($a,$b)" - listCall.lineNumber shouldBe Some(2) - inside(listCall.argument.l) { case List(aArg: Identifier, bArg: Identifier) => - aArg.name shouldBe "a" - aArg.code shouldBe "$a" - aArg.lineNumber shouldBe Some(2) - - bArg.name shouldBe "b" - bArg.code shouldBe "$b" - bArg.lineNumber shouldBe Some(2) - } - } - } - "include calls" should { "be correctly represented for normal includes" in { val cpg = code(""" $d) = $arr; + |""".stripMargin) + // finds the block containing the assignments + val block = cpg.all.collect { case block: Block if block.lineNumber.contains(2) => block }.head + inside(block.astChildren.assignment.l) { case tmp0 :: tmp1 :: tmp2 :: a :: b :: c :: d :: Nil => + tmp0.code shouldBe "$tmp0 = $arr" + tmp0.source.label shouldBe NodeTypes.IDENTIFIER + tmp0.source.code shouldBe "$arr" + tmp0.target.label shouldBe NodeTypes.IDENTIFIER + tmp0.target.code shouldBe "$tmp0" + + tmp1.code shouldBe "$tmp1 = $tmp0[0]" + tmp1.source.label shouldBe NodeTypes.CALL + tmp1.source.asInstanceOf[Call].name shouldBe Operators.indexAccess + tmp1.source.code shouldBe "$tmp0[0]" + tmp1.target.label shouldBe NodeTypes.IDENTIFIER + tmp1.target.code shouldBe "$tmp1" + + tmp2.code shouldBe "$tmp2 = $tmp1" + tmp2.source.label shouldBe NodeTypes.IDENTIFIER + tmp2.source.code shouldBe "$tmp1" + tmp2.target.label shouldBe NodeTypes.IDENTIFIER + tmp2.target.code shouldBe "$tmp2" + + a.code shouldBe "$a = $tmp2[0]" + a.source.label shouldBe NodeTypes.CALL + a.source.asInstanceOf[Call].name shouldBe Operators.indexAccess + a.source.code shouldBe "$tmp2[0]" + a.target.label shouldBe NodeTypes.IDENTIFIER + a.target.code shouldBe "$a" + + b.code shouldBe "$b = $tmp2[1]" + b.source.label shouldBe NodeTypes.CALL + b.source.asInstanceOf[Call].name shouldBe Operators.indexAccess + b.source.code shouldBe "$tmp2[1]" + b.target.label shouldBe NodeTypes.IDENTIFIER + b.target.code shouldBe "$b" + + c.code shouldBe "$c = $tmp0[1]" + c.source.label shouldBe NodeTypes.CALL + c.source.asInstanceOf[Call].name shouldBe Operators.indexAccess + c.source.code shouldBe "$tmp0[1]" + c.target.label shouldBe NodeTypes.IDENTIFIER + c.target.code shouldBe "$c" + + d.code shouldBe "$d = $tmp0[\"d\"]" + d.source.label shouldBe NodeTypes.CALL + d.source.asInstanceOf[Call].name shouldBe Operators.indexAccess + d.source.code shouldBe "$tmp0[\"d\"]" + d.target.label shouldBe NodeTypes.IDENTIFIER + d.target.code shouldBe "$d" + } + } } diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/PocTest.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/PocTest.scala index 83a8b537fce1..ca17c9c06936 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/PocTest.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/PocTest.scala @@ -4,7 +4,7 @@ import io.joern.php2cpg.astcreation.AstCreator.TypeConstants import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class PocTest extends PhpCode2CpgFixture { diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ScalarTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ScalarTests.scala index 0d913e2fa4cd..2f3b4e530e94 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ScalarTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/ScalarTests.scala @@ -1,7 +1,7 @@ package io.joern.php2cpg.querying import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal} class ScalarTests extends PhpCode2CpgFixture { diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeDeclTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeDeclTests.scala index 69d29a30da31..e69f35512630 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeDeclTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeDeclTests.scala @@ -3,12 +3,9 @@ package io.joern.php2cpg.querying import io.joern.php2cpg.Config import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture import io.joern.x2cpg.Defines +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{ModifierTypes, Operators} -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal, Local, Member, Method} import io.shiftleft.semanticcpg.language.* -import io.shiftleft.codepropertygraph.generated.nodes.Block -import io.shiftleft.codepropertygraph.generated.nodes.MethodRef -import io.shiftleft.codepropertygraph.generated.nodes.TypeRef class TypeDeclTests extends PhpCode2CpgFixture { @@ -248,19 +245,27 @@ class TypeDeclTests extends PhpCode2CpgFixture { clinitMethod.filename shouldBe "foo.php" clinitMethod.file.name.l shouldBe List("foo.php") - inside(clinitMethod.body.astChildren.l) { case List(aAssign: Call, bAssign: Call) => - aAssign.code shouldBe "A = \"A\"" - inside(aAssign.astChildren.l) { case List(aIdentifier: Identifier, aLiteral: Literal) => - aIdentifier.name shouldBe "A" - aIdentifier.code shouldBe "A" + inside(clinitMethod.body.astChildren.l) { case List(self: Local, aAssign: Call, bAssign: Call) => + aAssign.code shouldBe "self::A = \"A\"" + inside(aAssign.astChildren.l) { case List(aCall: Call, aLiteral: Literal) => + inside(aCall.argument.l) { case List(aSelf: Identifier, aField: FieldIdentifier) => + aSelf.name shouldBe "self" + aField.code shouldBe "A" + } + aCall.name shouldBe Operators.fieldAccess + aCall.code shouldBe "self::A" aLiteral.code shouldBe "\"A\"" } - bAssign.code shouldBe "B = \"B\"" - inside(bAssign.astChildren.l) { case List(bIdentifier: Identifier, bLiteral: Literal) => - bIdentifier.name shouldBe "B" - bIdentifier.code shouldBe "B" + bAssign.code shouldBe "self::B = \"B\"" + inside(bAssign.astChildren.l) { case List(bCall: Call, bLiteral: Literal) => + inside(bCall.argument.l) { case List(bSelf: Identifier, bField: FieldIdentifier) => + bSelf.name shouldBe "self" + bField.code shouldBe "B" + } + bCall.name shouldBe Operators.fieldAccess + bCall.code shouldBe "self::B" bLiteral.code shouldBe "\"B\"" } @@ -309,4 +314,40 @@ class TypeDeclTests extends PhpCode2CpgFixture { } } } + + "static/const member of class should be put in method" in { + val cpg = code(""" + inside(clinitMethod.body.astChildren.l) { case List(self: Local, bAssign: Call, aAssign: Call) => + self.name shouldBe "self" + inside(aAssign.astChildren.l) { case List(aCall: Call, aLiteral: Literal) => + inside(aCall.argument.l) { case List(aSelf: Identifier, aField: FieldIdentifier) => + aSelf.name shouldBe "self" + aField.code shouldBe "A" + } + aCall.name shouldBe Operators.fieldAccess + aCall.code shouldBe "self::$A" + + aLiteral.code shouldBe "\"A\"" + } + + inside(bAssign.astChildren.l) { case List(bCall: Call, bLiteral: Literal) => + inside(bCall.argument.l) { case List(bSelf: Identifier, bField: FieldIdentifier) => + bSelf.name shouldBe "self" + bField.code shouldBe "B" + } + bCall.name shouldBe Operators.fieldAccess + bCall.code shouldBe "self::B" // Notice there is no `$` in front of the const member + + bLiteral.code shouldBe "\"B\"" + } + } + } + } } diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeNodeTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeNodeTests.scala index c26317f4dc3f..2871427aae40 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeNodeTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeNodeTests.scala @@ -4,7 +4,7 @@ import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.{ModifierTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal, Local, Member, Method} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.nodes.Block class TypeNodeTests extends PhpCode2CpgFixture { @@ -29,7 +29,6 @@ class TypeNodeTests extends PhpCode2CpgFixture { |""".stripMargin) "have corresponding type nodes created" in { - println(cpg.literal.toList) cpg.typ.fullName.toSet shouldEqual Set("ANY", "int") } diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/UseTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/UseTests.scala index f9c911f8c617..040966272f3c 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/UseTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/UseTests.scala @@ -1,7 +1,7 @@ package io.joern.php2cpg.querying import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class UseTests extends PhpCode2CpgFixture { diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/testfixtures/PhpCode2CpgFixture.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/testfixtures/PhpCode2CpgFixture.scala index cb848636f420..0cf55b76fc97 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/testfixtures/PhpCode2CpgFixture.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/testfixtures/PhpCode2CpgFixture.scala @@ -1,10 +1,11 @@ package io.joern.php2cpg.testfixtures -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic +import io.joern.dataflowengineoss.DefaultSemantics +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.joern.dataflowengineoss.testfixtures.{SemanticCpgTestFixture, SemanticTestCpg} import io.joern.php2cpg.{Config, Php2Cpg} import io.joern.x2cpg.frontendspecific.php2cpg -import io.joern.x2cpg.testfixtures.{Code2CpgFixture, LanguageFrontend, DefaultTestCpg} +import io.joern.x2cpg.testfixtures.{Code2CpgFixture, DefaultTestCpg, LanguageFrontend} import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.semanticcpg.language.{ICallResolver, NoResolve} @@ -33,14 +34,14 @@ class PhpTestCpg extends DefaultTestCpg with PhpFrontend with SemanticTestCpg { class PhpCode2CpgFixture( runOssDataflow: Boolean = false, - extraFlows: List[FlowSemantic] = List.empty, + semantics: Semantics = DefaultSemantics(), withPostProcessing: Boolean = true ) extends Code2CpgFixture(() => new PhpTestCpg() .withOssDataflow(runOssDataflow) - .withExtraFlows(extraFlows) + .withSemantics(semantics) .withPostProcessingPasses(withPostProcessing) ) - with SemanticCpgTestFixture(extraFlows) { + with SemanticCpgTestFixture(semantics) { implicit val resolver: ICallResolver = NoResolve } diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ConfigFileCreationPass.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ConfigFileCreationPass.scala index fdf737a98ecb..088b0cd6638d 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ConfigFileCreationPass.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ConfigFileCreationPass.scala @@ -19,7 +19,10 @@ class ConfigFileCreationPass(cpg: Cpg, requirementsTxt: String = "requirement.tx // HTM files extensionFilter(".htm"), // Requirements.txt - pathEndFilter(requirementsTxt) + pathEndFilter(requirementsTxt), + // Pipfile + pathEndFilter("Pipfile"), + pathEndFilter("Pipfile.lock") ) } diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ContextStack.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ContextStack.scala index 0baefe4c5a6c..5df90e56804b 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ContextStack.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ContextStack.scala @@ -31,6 +31,7 @@ class ContextStack { val order: AutoIncIndex val variables: mutable.Map[String, nodes.NewNode] var lambdaCounter: Int + val methodCounter: mutable.Map[String, Int] } private class MethodContext( @@ -43,7 +44,8 @@ class ContextStack { val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, val globalVariables: mutable.Set[String] = mutable.Set.empty, val nonLocalVariables: mutable.Set[String] = mutable.Set.empty, - var lambdaCounter: Int = 0 + var lambdaCounter: Int = 0, + val methodCounter: mutable.Map[String, Int] = mutable.Map.empty ) extends Context {} private class ClassContext( @@ -51,7 +53,8 @@ class ContextStack { val astParent: nodes.NewNode, val order: AutoIncIndex, val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, - var lambdaCounter: Int = 0 + var lambdaCounter: Int = 0, + val methodCounter: mutable.Map[String, Int] = mutable.Map.empty ) extends Context {} // Used to represent comprehension variable and exception @@ -68,7 +71,8 @@ class ContextStack { val astParent: nodes.NewNode, val order: AutoIncIndex, val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, - var lambdaCounter: Int = 0 + var lambdaCounter: Int = 0, + val methodCounter: mutable.Map[String, Int] = mutable.Map.empty ) extends Context {} private case class VariableReference( @@ -156,7 +160,7 @@ class ContextStack { def createIdentifierLinks( createLocal: (String, Option[String]) => nodes.NewLocal, - createClosureBinding: (String, String) => nodes.NewClosureBinding, + createClosureBinding: (String) => nodes.NewClosureBinding, createAstEdge: (nodes.NewNode, nodes.NewNode, Int) => Unit, createRefEdge: (nodes.NewNode, nodes.NewNode) => Unit, createCaptureEdge: (nodes.NewNode, nodes.NewNode) => Unit @@ -260,7 +264,7 @@ class ContextStack { */ def considerAsGlobalVariable(lhs: NewNode): Unit = { lhs match { - case n: NewIdentifier if findEnclosingMethodContext(stack).scopeName.contains("") => + case n: NewIdentifier if findEnclosingMethodContext(stack).scopeName.contains(Constants.moduleName) => addGlobalVariable(n.name) case _ => } @@ -291,7 +295,7 @@ class ContextStack { private def linkLocalOrCapturing( createLocal: (String, Option[String]) => NewLocal, - createClosureBinding: (String, String) => NewClosureBinding, + createClosureBinding: (String) => NewClosureBinding, createAstEdge: (NewNode, NewNode, Int) => Unit, createRefEdge: (NewNode, NewNode) => Unit, createCaptureEdge: (NewNode, NewNode) => Unit, @@ -310,39 +314,43 @@ class ContextStack { case methodContext: MethodContext => // Context is only relevant for linking if it is not a class body methods context // or the identifier/reference itself is from the class body method context. - if (!methodContext.isClassBodyMethod || methodContext == startContext) { - contextHasVariable = context.variables.contains(name) - - val closureBindingId = - methodContext.astParent.asInstanceOf[NewMethod].fullName + ":" + name - - if (!contextHasVariable) { - if (context != moduleMethodContext.get) { - val localNode = createLocal(name, Some(closureBindingId)) - transferLineColInfo(identifier, localNode) - createAstEdge(localNode, methodContext.methodBlockNode.get, methodContext.order.getAndInc) - methodContext.variables.put(name, localNode) - } else { - // When we could not even find a matching variable in the module context we get - // here and create a local so that we can link something and fullfil the CPG - // format requirements. - // For example this happens when there are wildcard imports directly into the - // modules namespace. - val localNode = createLocal(name, None) - transferLineColInfo(identifier, localNode) - createAstEdge(localNode, methodContext.methodBlockNode.get, methodContext.order.getAndInc) - methodContext.variables.put(name, localNode) - } + val mangledName = + if (!methodContext.isClassBodyMethod || methodContext == startContext) { + name + } else { + s"$name" } - val localNodeInContext = methodContext.variables(name) + contextHasVariable = context.variables.contains(mangledName) - createRefEdge(localNodeInContext, identifierOrClosureBindingToLink) + val closureBindingId = + methodContext.astParent.asInstanceOf[NewMethod].fullName + ":" + name - if (!contextHasVariable && context != moduleMethodContext.get) { - identifierOrClosureBindingToLink = createClosureBinding(closureBindingId, name) - createCaptureEdge(identifierOrClosureBindingToLink, methodContext.methodRefNode.get) + if (!contextHasVariable) { + if (context != moduleMethodContext.get) { + val localNode = createLocal(mangledName, Some(closureBindingId)) + transferLineColInfo(identifier, localNode) + createAstEdge(localNode, methodContext.methodBlockNode.get, methodContext.order.getAndInc) + methodContext.variables.put(mangledName, localNode) + } else { + // When we could not even find a matching variable in the module context we get + // here and create a local so that we can link something and fullfil the CPG + // format requirements. + // For example this happens when there are wildcard imports directly into the + // modules namespace. + val localNode = createLocal(mangledName, None) + transferLineColInfo(identifier, localNode) + createAstEdge(localNode, methodContext.methodBlockNode.get, methodContext.order.getAndInc) + methodContext.variables.put(mangledName, localNode) } } + val localNodeInContext = methodContext.variables(mangledName) + + createRefEdge(localNodeInContext, identifierOrClosureBindingToLink) + + if (!contextHasVariable && context != moduleMethodContext.get) { + identifierOrClosureBindingToLink = createClosureBinding(closureBindingId) + createCaptureEdge(identifierOrClosureBindingToLink, methodContext.methodRefNode.get) + } case specialBlockContext: SpecialBlockContext => contextHasVariable = context.variables.contains(name) if (contextHasVariable) { @@ -429,4 +437,8 @@ class ContextStack { }) } + def methodCounter: mutable.Map[String, Int] = { + stack.head.methodCounter + } + } diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala index 0afff1632080..01a932277c44 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala @@ -2,7 +2,7 @@ package io.joern.pysrc2cpg import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.nodes.{NewDependency} import org.slf4j.{Logger, LoggerFactory} diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/EdgeBuilder.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/EdgeBuilder.scala index fdb409d6ab0b..2973108ba84d 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/EdgeBuilder.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/EdgeBuilder.scala @@ -23,7 +23,7 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ NewUnknown } import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder class EdgeBuilder(diffGraph: DiffGraphBuilder) { def astEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, order: Int): Unit = { diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Main.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Main.scala index 2ab54ea2c507..7bf13ec760c1 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Main.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Main.scala @@ -3,6 +3,7 @@ package io.joern.pysrc2cpg import io.joern.pysrc2cpg.Frontend.cmdLineParser import io.joern.x2cpg.X2CpgMain import io.joern.x2cpg.passes.frontend.XTypeRecoveryConfig +import io.joern.x2cpg.utils.server.FrontendHTTPServer import scopt.OParser import java.nio.file.Paths @@ -38,9 +39,15 @@ private object Frontend { } } -object NewMain extends X2CpgMain(cmdLineParser, new Py2CpgOnFileSystem())(new Py2CpgOnFileSystemConfig()) { +object NewMain + extends X2CpgMain(cmdLineParser, new Py2CpgOnFileSystem())(Py2CpgOnFileSystemConfig()) + with FrontendHTTPServer[Py2CpgOnFileSystemConfig, Py2CpgOnFileSystem] { + + override protected def newDefaultConfig(): Py2CpgOnFileSystemConfig = Py2CpgOnFileSystemConfig() + def run(config: Py2CpgOnFileSystemConfig, frontend: Py2CpgOnFileSystem): Unit = { - frontend.run(config) + if (config.serverMode) { startup() } + else { frontend.run(config) } } def getCmdLineParser = cmdLineParser diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/NodeBuilder.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/NodeBuilder.scala index e1364856880e..6a4e0d837d92 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/NodeBuilder.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/NodeBuilder.scala @@ -7,7 +7,7 @@ import io.joern.x2cpg.Defines import io.joern.x2cpg.frontendspecific.pysrc2cpg.Constants import io.joern.x2cpg.utils.NodeBuilders import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EvaluationStrategies, nodes} -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder class NodeBuilder(diffGraph: DiffGraphBuilder) { @@ -126,12 +126,12 @@ class NodeBuilder(diffGraph: DiffGraphBuilder) { addNodeToDiff(methodRefNode) } - def closureBindingNode(closureBindingId: String, closureOriginalName: String): nodes.NewClosureBinding = { + def closureBindingNode(closureBindingId: String): nodes.NewClosureBinding = { val closureBindingNode = nodes .NewClosureBinding() .closureBindingId(Some(closureBindingId)) .evaluationStrategy(EvaluationStrategies.BY_REFERENCE) - .closureOriginalName(Some(closureOriginalName)) + .closureOriginalName(None) addNodeToDiff(closureBindingNode) } diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Py2Cpg.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Py2Cpg.scala index e951442986bb..b8ab7f82051b 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Py2Cpg.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Py2Cpg.scala @@ -2,10 +2,10 @@ package io.joern.pysrc2cpg import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.frontendspecific.pysrc2cpg.Constants -import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder} import io.shiftleft.codepropertygraph.generated.Languages -import overflowdb.BatchedUpdate -import overflowdb.BatchedUpdate.DiffGraphBuilder +import flatgraph.DiffGraphApplier +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder object Py2Cpg { case class InputPair(content: String, relFileName: String) @@ -45,7 +45,7 @@ class Py2Cpg( val anyTypeDecl = nodeBuilder.typeDeclNode(Constants.ANY, Constants.ANY, "N/A", Nil, LineAndColumn(1, 1, 1, 1, 1, 1)) edgeBuilder.astEdge(anyTypeDecl, globalNamespaceBlock, 0) - BatchedUpdate.applyDiff(outputCpg.graph, diffGraph) + DiffGraphApplier.applyDiff(outputCpg.graph, diffGraph) new CodeToCpg(outputCpg, inputProviders, schemaValidationMode, enableFileContent).createAndApply() new ConfigFileCreationPass(outputCpg, requirementsTxt).createAndApply() new DependenciesFromRequirementsTxtPass(outputCpg).createAndApply() diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Py2CpgOnFileSystem.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Py2CpgOnFileSystem.scala index 58e58dd8da88..391e044dca21 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Py2CpgOnFileSystem.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/Py2CpgOnFileSystem.scala @@ -8,7 +8,7 @@ import org.slf4j.LoggerFactory import java.nio.file.* import scala.util.Try -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* case class Py2CpgOnFileSystemConfig( venvDir: Option[Path] = None, diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitor.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitor.scala index f1b8f7993a4b..0aeb0afda295 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitor.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitor.scala @@ -4,12 +4,13 @@ import PythonAstVisitor.{logger, metaClassSuffix, noLineAndColumn} import io.joern.pysrc2cpg.memop.* import io.joern.x2cpg.frontendspecific.pysrc2cpg.Constants.builtinPrefix import io.joern.pythonparser.ast +import io.joern.pythonparser.ast.{Arguments, iexpr, istmt} import io.joern.x2cpg.frontendspecific.pysrc2cpg.Constants import io.joern.x2cpg.{AstCreatorBase, ValidationMode} import io.shiftleft.codepropertygraph.generated.* import io.shiftleft.codepropertygraph.generated.nodes.{NewCall, NewIdentifier, NewNode, NewTypeDecl} import org.slf4j.LoggerFactory -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import scala.collection.mutable @@ -34,6 +35,8 @@ class PythonAstVisitor( extends AstCreatorBase(relFileName) with PythonAstVisitorHelpers { + private val redefintionSuffix = "$redefinition" + private val diffGraph = Cpg.newDiffGraphBuilder protected val nodeBuilder = new NodeBuilder(diffGraph) protected val edgeBuilder = new EdgeBuilder(diffGraph) @@ -93,7 +96,7 @@ class PythonAstVisitor( edgeBuilder.astEdge(namespaceBlockNode, fileNode, 1) contextStack.setFileNamespaceBlock(namespaceBlockNode) - val methodFullName = calculateFullNameFromContext("") + val methodFullName = calculateFullNameFromContext(Constants.moduleName) val firstLineAndCol = module.stmts.headOption.map(lineAndColOf) val lastLineAndCol = module.stmts.lastOption.map(lineAndColOf) @@ -106,9 +109,9 @@ class PythonAstVisitor( val moduleMethodNode = createMethod( - "", + Constants.moduleName, methodFullName, - Some(""), + Some(Constants.moduleName), ModifierTypes.VIRTUAL :: ModifierTypes.MODULE :: Nil, parameterProvider = () => MethodParameters.empty(), bodyProvider = () => createBuiltinIdentifiers(memOpCalculator.names) ++ module.stmts.map(convert), @@ -224,26 +227,6 @@ class PythonAstVisitor( } } - def convert(functionDef: ast.FunctionDef): NewNode = { - val methodIdentifierNode = - createIdentifierNode(functionDef.name, Store, lineAndColOf(functionDef)) - val (methodNode, methodRefNode) = createMethodAndMethodRef( - functionDef.name, - Some(functionDef.name), - createParameterProcessingFunction(functionDef.args, isStaticMethod(functionDef.decorator_list)), - () => functionDef.body.map(convert), - functionDef.returns, - isAsync = false, - lineAndColOf(functionDef) - ) - functionDefToMethod.put(functionDef, methodNode) - - val wrappedMethodRefNode = - wrapMethodRefWithDecorators(methodRefNode, functionDef.decorator_list) - - createAssignment(methodIdentifierNode, wrappedMethodRefNode, lineAndColOf(functionDef)) - } - /* * For a decorated function like: * @f1(arg) @@ -267,26 +250,58 @@ class PythonAstVisitor( ) } - def convert(functionDef: ast.AsyncFunctionDef): NewNode = { + private def convertFunctionInternal( + name: String, + args: Arguments, + decoratorList: ast.CollType[iexpr], + body: ast.CollType[istmt], + returns: Option[iexpr], + isAsync: Boolean, + functionDef: istmt + ): NewNode = { val methodIdentifierNode = - createIdentifierNode(functionDef.name, Store, lineAndColOf(functionDef)) + createIdentifierNode(name, Store, lineAndColOf(functionDef)) val (methodNode, methodRefNode) = createMethodAndMethodRef( - functionDef.name, - Some(functionDef.name), - createParameterProcessingFunction(functionDef.args, isStaticMethod(functionDef.decorator_list)), - () => functionDef.body.map(convert), - functionDef.returns, - isAsync = true, + name, + Some(name), + createParameterProcessingFunction(args, isStaticMethod(decoratorList)), + () => body.map(convert), + returns, + isAsync, lineAndColOf(functionDef) ) functionDefToMethod.put(functionDef, methodNode) val wrappedMethodRefNode = - wrapMethodRefWithDecorators(methodRefNode, functionDef.decorator_list) + wrapMethodRefWithDecorators(methodRefNode, decoratorList) createAssignment(methodIdentifierNode, wrappedMethodRefNode, lineAndColOf(functionDef)) } + def convert(functionDef: ast.FunctionDef): NewNode = { + convertFunctionInternal( + functionDef.name, + functionDef.args, + functionDef.decorator_list, + functionDef.body, + functionDef.returns, + isAsync = false, + functionDef + ) + } + + def convert(functionDef: ast.AsyncFunctionDef): NewNode = { + convertFunctionInternal( + functionDef.name, + functionDef.args, + functionDef.decorator_list, + functionDef.body, + functionDef.returns, + isAsync = true, + functionDef + ) + } + private def isStaticMethod(decoratorList: Iterable[ast.iexpr]): Boolean = { decoratorList.exists { case name: ast.Name if name.id == "staticmethod" => true @@ -325,7 +340,14 @@ class PythonAstVisitor( lineAndColumn: LineAndColumn, additionalModifiers: List[String] = List.empty ): (nodes.NewMethod, nodes.NewMethodRef) = { - val methodFullName = calculateFullNameFromContext(methodName) + val suffix = + contextStack.methodCounter.get(methodName) match { + case Some(counter) => + redefintionSuffix + counter.toString + case None => + "" + } + val methodFullName = calculateFullNameFromContext(methodName) + suffix val methodRefNode = nodeBuilder.methodRefNode("def " + methodName + "(...)", methodFullName, lineAndColumn) @@ -345,6 +367,11 @@ class PythonAstVisitor( lineAndColumn ) + contextStack.methodCounter.updateWith(methodName) { + case None => Some(1) + case Some(counter) => Some(counter + 1) + } + (methodNode, methodRefNode) } @@ -403,7 +430,7 @@ class PythonAstVisitor( // For every method that is a module, the local variables can be imported by other modules. This behaviour is // much like fields so they are to be linked as fields to this method type - if (name == "") contextStack.createMemberLinks(typeDeclNode, edgeBuilder.astEdge) + if (name == Constants.moduleName) contextStack.createMemberLinks(typeDeclNode, edgeBuilder.astEdge) contextStack.pop() edgeBuilder.astEdge(typeDeclNode, contextStack.astParent, contextStack.order.getAndInc) @@ -473,7 +500,11 @@ class PythonAstVisitor( val (_, methodRefNode) = createMethodAndMethodRef( classBodyFunctionName, scopeName = None, - parameterProvider = () => MethodParameters.empty(), + parameterProvider = () => + MethodParameters( + 0, + nodeBuilder.methodParameterNode("cls", isVariadic = false, lineAndColOf(classDef), Option(0)) :: Nil + ), bodyProvider = () => classDef.body.map(convert), None, isAsync = false, @@ -488,7 +519,7 @@ class PythonAstVisitor( val functions = classDef.body.collect { case func: ast.FunctionDef => func } // __init__ method has to be in functions because "async def __init__" is invalid. - val initFunctionOption = functions.find(_.name == "__init__") + val initFunctionOption = functions.find(_.name == Constants.initName) val initParameters = initFunctionOption.map(_.args).getOrElse { // Create arguments of a default __init__ function. @@ -522,40 +553,58 @@ class PythonAstVisitor( // For non static methods we create an adapter method which basically only shifts the parameters // one to the left and makes sure that the meta class object is not passed to func as instance // parameter. - classDef.body.foreach { - case func: ast.FunctionDef => - createMemberBindingsAndAdapter( - func, - func.name, - func.args, - func.decorator_list, - instanceTypeDecl, - metaTypeDeclNode - ) - case func: ast.AsyncFunctionDef => - createMemberBindingsAndAdapter( - func, - func.name, - func.args, - func.decorator_list, - instanceTypeDecl, - metaTypeDeclNode - ) - case _ => - // All other body statements are currently ignored. - } + classDef.body + // Filter for functions and build tuples with their name + .flatMap { + case func: ast.FunctionDef => + Some((func.name, func)) + case func: ast.AsyncFunctionDef => + Some((func.name, func)) + case _ => + None + } + // Group by name and remove name from value + .groupMap(_._1)(_._2) + // Sort by name to get a stable output + .toBuffer + .sortBy(_._1) + // Take the last function. We only create member/binding + // for the last definition as it overwrites the previous ones. + .map { case (_, functions) => functions.last } + .foreach { + case func: ast.FunctionDef => + createMemberBindingsAndAdapter( + func, + func.name, + func.args, + func.decorator_list, + instanceTypeDecl, + metaTypeDeclNode + ) + case func: ast.AsyncFunctionDef => + createMemberBindingsAndAdapter( + func, + func.name, + func.args, + func.decorator_list, + instanceTypeDecl, + metaTypeDeclNode + ) + } contextStack.pop() - // Create call to function and assignment of the meta class object to a identifier named - // like the class. - val callToClassBodyFunction = createCall(methodRefNode, "", lineAndColOf(classDef), Nil, Nil) val metaTypeRefNode = createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColOf(classDef)) val classIdentifierAssignNode = createAssignmentToIdentifier(classDef.name, metaTypeRefNode, lineAndColOf(classDef)) + // Create call to function and assignment of the meta class object to a identifier named + // like the class. + val classIdentifierForCall = createIdentifierNode(classDef.name, Load, lineAndColOf(classDef)) + val callToClassBodyFunction = + createInstanceCall(methodRefNode, classIdentifierForCall, "", lineAndColOf(classDef), Nil, Nil) - val classBlock = createBlock(callToClassBodyFunction :: classIdentifierAssignNode :: Nil, lineAndColOf(classDef)) + val classBlock = createBlock(classIdentifierAssignNode :: callToClassBodyFunction :: Nil, lineAndColOf(classDef)) classBlock } @@ -774,7 +823,7 @@ class PythonAstVisitor( val initCall = createXDotYCall( () => createIdentifierNode("cls", Load, lineAndColumn), - "__init__", + Constants.initName, xMayHaveSideEffects = false, lineAndColumn, argumentWithInstance, @@ -812,11 +861,30 @@ class PythonAstVisitor( val loweredNodes = createValueToTargetsDecomposition(assign.targets, convert(assign.value), lineAndColOf(assign)) - if (loweredNodes.size == 1) { + val assignmentsToMembers = + if (contextStack.isClassContext) { + // In addition to the left hand side identifier(s) created by createValueToTargetsDecomposition + // we here create `cls. = ` if we are in a class body function to + // represent the assignment into a member of the same name in the meta class. + assign.targets.collect { case nameTarget: ast.Name => + assert(memOpMap.get(nameTarget).get == Store) + val lineAndColumn = lineAndColOf(nameTarget) + val classIdentifier = createIdentifierNode("cls", Load, lineAndColumn) + val targetFieldAccess = createFieldAccess(classIdentifier, nameTarget.id, lineAndColumn) + val targetIdentifier = createIdentifierNode(nameTarget.id, Load, lineAndColumn) + createAssignment(targetFieldAccess, targetIdentifier, lineAndColumn) + } + } else { + Nil + } + + val combinedLoweredNodes = loweredNodes ++ assignmentsToMembers + + if (combinedLoweredNodes.size == 1) { // Simple assignment can be returned directly. - loweredNodes.head + combinedLoweredNodes.head } else { - createBlock(loweredNodes, lineAndColOf(assign)) + createBlock(combinedLoweredNodes, lineAndColOf(assign)) } } diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/AstPrinter.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/AstPrinter.scala index 18a3cb30d32c..3e8864537c8b 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/AstPrinter.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/AstPrinter.scala @@ -1,5 +1,5 @@ package io.joern.pythonparser -import io.joern.pythonparser.ast._ +import io.joern.pythonparser.ast.* import scala.collection.immutable class AstPrinter(indentStr: String) extends AstVisitor[String] { diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/AstVisitor.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/AstVisitor.scala index 4a9617055bdd..167c2287c689 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/AstVisitor.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/AstVisitor.scala @@ -1,6 +1,6 @@ package io.joern.pythonparser -import io.joern.pythonparser.ast._ +import io.joern.pythonparser.ast.* trait AstVisitor[T] { def visit(ast: iast): T diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/PyParser.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/PyParser.scala index 36040ca8274b..826960e7f439 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/PyParser.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/PyParser.scala @@ -6,7 +6,7 @@ import io.joern.pythonparser.ast.{ErrorStatement, iast} import java.io.{BufferedReader, ByteArrayInputStream, InputStream, Reader} import java.nio.charset.StandardCharsets -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* class PyParser { private var pythonParser: PythonParser = scala.compiletime.uninitialized diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/ast/Ast.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/ast/Ast.scala index 02c5272c4df7..0ba09590deb0 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/ast/Ast.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pythonparser/ast/Ast.scala @@ -3,7 +3,7 @@ package io.joern.pythonparser.ast import io.joern.pythonparser.AstVisitor import java.util -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* // This file describes the AST classes. // It tries to stay as close as possible to the AST defined by CPython at diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/config/ConfigTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/config/ConfigTests.scala index 49b8be9900b0..b8bec030fcf3 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/config/ConfigTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/config/ConfigTests.scala @@ -18,7 +18,9 @@ class ConfigTests extends AnyWordSpec with Matchers with Inside { "--output", "OUTPUT", "--exclude", - "1EXCLUDE_FILE,2EXCLUDE_FILE", + "1EXCLUDE_FILE", + "--exclude", + "2EXCLUDE_FILE", "--exclude-regex", "EXCLUDE_REGEX", // Frontend-specific args diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssertCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssertCpgTests.scala index 1f7ef899b70f..28cce1f5d986 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssertCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssertCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssignCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssignCpgTests.scala index dca961f63eef..2811650be190 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssignCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssignCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, nodes} -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AttributeCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AttributeCpgTests.scala index e64fb1369e79..21d6ca62a217 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AttributeCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AttributeCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BinOpCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BinOpCpgTests.scala index 1e2dc03cb42c..7b0917042e1e 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BinOpCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BinOpCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BoolOpCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BoolOpCpgTests.scala index 0d5ffa521e3a..c210a453b33c 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BoolOpCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BoolOpCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BuiltinIdentifierTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BuiltinIdentifierTests.scala index fff6c7d5889c..3dae7d13a396 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BuiltinIdentifierTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BuiltinIdentifierTests.scala @@ -1,7 +1,7 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BytesLiteralCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BytesLiteralCpgTests.scala index 1238e6ed089b..4aafd1d33add 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BytesLiteralCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/BytesLiteralCpgTests.scala @@ -1,6 +1,6 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.joern.x2cpg.frontendspecific.pysrc2cpg.Constants import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala index 7ec13f219950..3337cc5bf97d 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala @@ -1,9 +1,8 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.PySrc2CpgFixture +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.semanticcpg.language._ -import overflowdb.traversal.NodeOps +import io.shiftleft.semanticcpg.language.* import java.io.File diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ClassCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ClassCpgTests.scala index aed25d40e211..eb59960a2c15 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ClassCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ClassCpgTests.scala @@ -1,7 +1,8 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.PySrc2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture +import io.shiftleft.codepropertygraph.generated.nodes.Call +import io.shiftleft.semanticcpg.language.* class ClassCpgTests extends PySrc2CpgFixture(withOssDataflow = false) { "class" should { @@ -109,4 +110,15 @@ class ClassCpgTests extends PySrc2CpgFixture(withOssDataflow = false) { } } + "assignment in class body should result into member write on meta class object" in { + val cpg = code("""class Foo: + | AAA = 111 + |""".stripMargin) + + val List(bodyMethod) = cpg.method.name("").l + val List(block) = bodyMethod.topLevelExpressions.l + val List(_, memberAssignment) = block.astChildren.collectAll[Call].l + memberAssignment.code shouldBe "cls.AAA = AAA" + } + } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CompareCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CompareCpgTests.scala index f01d5f698aa7..3550933f5e3f 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CompareCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CompareCpgTests.scala @@ -1,9 +1,9 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ContentCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ContentCpgTests.scala index ccda1f361591..ed60926e22c7 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ContentCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ContentCpgTests.scala @@ -1,6 +1,6 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.PySrc2CpgFixture +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture import io.shiftleft.semanticcpg.language.* class ContentCpgTests extends PySrc2CpgFixture() { diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/DeleteCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/DeleteCpgTests.scala index cc1aa72bf86f..a0b02eeed10e 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/DeleteCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/DeleteCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/DictCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/DictCpgTests.scala index c1698709dd6c..2e996ab2c924 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/DictCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/DictCpgTests.scala @@ -1,10 +1,8 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.{Py2CpgTestContext, PySrc2CpgFixture} -import io.shiftleft.codepropertygraph.generated.DispatchTypes +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture import io.shiftleft.semanticcpg.language.* -import org.scalatest.matchers.should.Matchers -import org.scalatest.wordspec.AnyWordSpec class DictCpgTests extends PySrc2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/FormatStringCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/FormatStringCpgTests.scala index 113401a91718..586b04fe7dd2 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/FormatStringCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/FormatStringCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/FunctionDefCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/FunctionDefCpgTests.scala index 8c526ec859bb..d9dde2aa24f2 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/FunctionDefCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/FunctionDefCpgTests.scala @@ -1,6 +1,6 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.joern.x2cpg.frontendspecific.pysrc2cpg.Constants import io.shiftleft.codepropertygraph.generated.ModifierTypes import io.shiftleft.codepropertygraph.generated.nodes.Call @@ -55,12 +55,12 @@ class FunctionDefCpgTests extends AnyFreeSpec with Matchers { } "test function method ref" in { - cpg.methodRef("func").referencedMethod.fullName.head shouldBe + cpg.methodRefWithName("func").referencedMethod.fullName.head shouldBe "test.py:.func" } "test assignment of method ref to local variable" in { - val assignNode = cpg.methodRef("func").astParent.isCall.head + val assignNode = cpg.methodRefWithName("func").astParent.isCall.head assignNode.code shouldBe "func = def func(...)" } @@ -132,7 +132,7 @@ class FunctionDefCpgTests extends AnyFreeSpec with Matchers { |""".stripMargin) "test decorator wrapping of method reference" in { - val (staticMethod: Call) :: Nil = cpg.methodRef("func").astParent.l: @unchecked + val (staticMethod: Call) :: Nil = cpg.methodRefWithName("func").astParent.l: @unchecked staticMethod.code shouldBe "staticmethod(def func(...))" staticMethod.name shouldBe "staticmethod" val (abc: Call) :: Nil = staticMethod.start.astParent.l: @unchecked diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/IfCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/IfCpgTests.scala index cbb02e48d5ba..6647fe7454d1 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/IfCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/IfCpgTests.scala @@ -1,7 +1,7 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext +import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, nodes} import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ImportCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ImportCpgTests.scala index 122cccd6d386..4d066ba21c45 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ImportCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ImportCpgTests.scala @@ -1,6 +1,6 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/IntLiteralCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/IntLiteralCpgTests.scala index 5ae76bdb79f7..d3c94e23307d 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/IntLiteralCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/IntLiteralCpgTests.scala @@ -1,6 +1,6 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.joern.x2cpg.frontendspecific.pysrc2cpg.Constants import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ListCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ListCpgTests.scala index 73f85aa73fab..b02fd1a9c3cc 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ListCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ListCpgTests.scala @@ -1,7 +1,7 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/MemberCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/MemberCpgTests.scala index f8e88367fc3c..448cd17a4f9c 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/MemberCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/MemberCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.nodes.Member -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/MethodCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/MethodCpgTests.scala index 6e03d0808b92..84c17a05faa9 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/MethodCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/MethodCpgTests.scala @@ -1,7 +1,7 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers @@ -22,4 +22,33 @@ class MethodCpgTests extends AnyFreeSpec with Matchers { } } + "test method redefinition" in { + val cpg = Py2CpgTestContext.buildCpg( + """ + |class Foo(): + | def method(): + | pass + | def method(): + | pass + | def method(): + | pass + |""".stripMargin, + "a/b.py" + ) + + cpg.method.name("method").map(m => (m.name, m.fullName)).l should contain theSameElementsAs (List( + ("method", "a/b.py:.Foo.method"), + ("method", "a/b.py:.Foo.method$redefinition1"), + ("method", "a/b.py:.Foo.method$redefinition2") + )) + + cpg.typeDecl.name("Foo").member.name("method").dynamicTypeHintFullName.l should contain theSameElementsAs ( + List("a/b.py:.Foo.method$redefinition2") + ) + + cpg.typeDecl.name("Foo").member.name("method").dynamicTypeHintFullName.l should contain theSameElementsAs ( + List("a/b.py:.Foo.method") + ) + } + } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ModuleFunctionCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ModuleFunctionCpgTests.scala index 57c274c7502d..a35deb9282e9 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ModuleFunctionCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ModuleFunctionCpgTests.scala @@ -1,7 +1,7 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/PatternMatchingTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/PatternMatchingTests.scala index 75bacc6c1dac..9510fd3d5ac7 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/PatternMatchingTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/PatternMatchingTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.PySrc2CpgFixture +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture import io.shiftleft.codepropertygraph.generated.NodeTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class PatternMatchingTests extends PySrc2CpgFixture() { "pattern matching" should { diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/RaiseCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/RaiseCpgTests.scala index 58dd9b40fae4..32c14bcb75e9 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/RaiseCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/RaiseCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.Operators -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ReturnCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ReturnCpgTests.scala index 631bddb4aa5e..1b17c2a3b039 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ReturnCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ReturnCpgTests.scala @@ -1,7 +1,7 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SetCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SetCpgTests.scala index d23fa0e133af..10f518a9a70a 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SetCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SetCpgTests.scala @@ -1,6 +1,6 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.semanticcpg.language.* import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SliceCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SliceCpgTests.scala index ece135da41e0..f48c201a8aba 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SliceCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SliceCpgTests.scala @@ -1,6 +1,6 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.PySrc2CpgFixture +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture import io.shiftleft.semanticcpg.language.* class SliceCpgTests extends PySrc2CpgFixture() { diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StarredCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StarredCpgTests.scala index 67c53bd02f5a..0b90acbbfa41 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StarredCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StarredCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StrLiteralCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StrLiteralCpgTests.scala index 2cf2a22b2cda..9a3455c7abd3 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StrLiteralCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StrLiteralCpgTests.scala @@ -1,6 +1,6 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.joern.x2cpg.frontendspecific.pysrc2cpg.Constants import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StringExpressionListCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StringExpressionListCpgTests.scala index e926441a1940..feb7aff51640 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StringExpressionListCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/StringExpressionListCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SubscriptCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SubscriptCpgTests.scala index 134d0bd16608..adf4531c24ea 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SubscriptCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/SubscriptCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/TryCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/TryCpgTests.scala index 103473b7bf72..92a0f02ba0a8 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/TryCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/TryCpgTests.scala @@ -1,6 +1,6 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/UnaryOpCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/UnaryOpCpgTests.scala index 0c7739cda935..f98cb73ca464 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/UnaryOpCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/UnaryOpCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/VariableReferencingCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/VariableReferencingCpgTests.scala index 9da932fb2023..ad99942b9a13 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/VariableReferencingCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/VariableReferencingCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext import io.shiftleft.codepropertygraph.generated.EvaluationStrategies -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers @@ -149,11 +149,10 @@ class VariableReferencingCpgTests extends AnyFreeSpec with Matchers { } "test method reference closure binding" in { - val methodRefNode = cpg.methodRef("f").head + val methodRefNode = cpg.methodRefWithName("f").head val closureBinding = methodRefNode._closureBindingViaCaptureOut.next() closureBinding.closureBindingId shouldBe Some("test.py:.f:x") closureBinding.evaluationStrategy shouldBe EvaluationStrategies.BY_REFERENCE - closureBinding.closureOriginalName shouldBe Some("x") } "test global variable exists" in { @@ -185,11 +184,10 @@ class VariableReferencingCpgTests extends AnyFreeSpec with Matchers { } "test method reference closure binding" in { - val methodRefNode = cpg.methodRef("f").head + val methodRefNode = cpg.methodRefWithName("f").head val closureBinding = methodRefNode._closureBindingViaCaptureOut.next() closureBinding.closureBindingId shouldBe Some("test.py:.f:x") closureBinding.evaluationStrategy shouldBe EvaluationStrategies.BY_REFERENCE - closureBinding.closureOriginalName shouldBe Some("x") } "test global variable exists" in { @@ -223,11 +221,10 @@ class VariableReferencingCpgTests extends AnyFreeSpec with Matchers { } "test method reference closure binding of f in g" in { - val methodRefNode = cpg.methodRef("f").head + val methodRefNode = cpg.methodRefWithName("f").head val closureBinding = methodRefNode._closureBindingViaCaptureOut.next() closureBinding.closureBindingId shouldBe Some("test.py:.g.f:x") closureBinding.evaluationStrategy shouldBe EvaluationStrategies.BY_REFERENCE - closureBinding.closureOriginalName shouldBe Some("x") } "test local variable exists in g" in { @@ -240,11 +237,10 @@ class VariableReferencingCpgTests extends AnyFreeSpec with Matchers { } "test method reference closure binding of g in module" in { - val methodRefNode = cpg.methodRef("g").head + val methodRefNode = cpg.methodRefWithName("g").head val closureBinding = methodRefNode._closureBindingViaCaptureOut.next() closureBinding.closureBindingId shouldBe Some("test.py:.g:x") closureBinding.evaluationStrategy shouldBe EvaluationStrategies.BY_REFERENCE - closureBinding.closureOriginalName shouldBe Some("x") } "test global variable exists" in { @@ -259,7 +255,8 @@ class VariableReferencingCpgTests extends AnyFreeSpec with Matchers { } "reference from class method" - { - lazy val cpg = Py2CpgTestContext.buildCpg("""x = 0 + lazy val cpg = Py2CpgTestContext.buildCpg(""" + |x = 0 |class MyClass(): | x = 1 | def f(self): @@ -267,15 +264,20 @@ class VariableReferencingCpgTests extends AnyFreeSpec with Matchers { |""".stripMargin) "test capturing to global x exists" in { - val moduleLocal = cpg.method.name("").local.name("x").head - moduleLocal._closureBindingViaRefIn.next().closureBindingId shouldBe Some("test.py:.MyClass.f:x") + val moduleXLocal = cpg.method.name("").local.name("x").head + val moduleXBinding = moduleXLocal._closureBindingViaRefIn.next() + moduleXBinding.closureBindingId shouldBe Some("test.py:.MyClass.:x") + + val bodyXLocal = cpg.method.fullName("test.py:.MyClass.").local.name("x").head + bodyXLocal.closureBindingId shouldBe None - val bodyLocal = cpg.method.fullName("test.py:.MyClass.").local.name("x").head - bodyLocal.closureBindingId shouldBe None + val capturedBodyXLocal = cpg.method.fullName("test.py:.MyClass.").local.name("x").head + capturedBodyXLocal.closureBindingId shouldBe Some("test.py:.MyClass.:x") + val bodyXBinding = capturedBodyXLocal._closureBindingViaRefIn.next() + bodyXBinding.closureBindingId shouldBe Some("test.py:.MyClass.f:x") val fLocal = cpg.method.fullName("test.py:.MyClass.f").local.name("x").head fLocal.closureBindingId shouldBe Some("test.py:.MyClass.f:x") - } } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/WhileCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/WhileCpgTests.scala index 661c3bf6cb5e..b2c771250c7a 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/WhileCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/WhileCpgTests.scala @@ -1,6 +1,7 @@ package io.joern.pysrc2cpg.cpg -import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ + +import io.joern.pysrc2cpg.testfixtures.Py2CpgTestContext +import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, nodes} import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/dataflow/DataFlowTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/dataflow/DataFlowTests.scala index 034b5183c0b7..cf3812d311db 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/dataflow/DataFlowTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/dataflow/DataFlowTests.scala @@ -1,8 +1,16 @@ package io.joern.pysrc2cpg.dataflow +import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.language.toExtendedCfgNode -import io.joern.dataflowengineoss.semanticsloader.{FlowMapping, FlowSemantic, PassThroughMapping} -import io.joern.pysrc2cpg.PySrc2CpgFixture +import io.joern.dataflowengineoss.semanticsloader.{ + FlowMapping, + FlowSemantic, + NilSemantics, + NoCrossTaintSemantics, + NoSemantics, + PassThroughMapping +} +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Literal, Member, Method} import io.shiftleft.semanticcpg.language.* @@ -63,7 +71,7 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { |a = 20 |print(foo(a)) |""".stripMargin) - .withExtraFlows(List(FlowSemantic("helpers.py:.foo", List()))) + .withSemantics(DefaultSemantics().after(NilSemantics.where(List("helpers.py:.foo")))) val source = cpg.literal("20").l val sink = cpg.call("print").argument(1).l val flows = sink.reachableByFlows(source).l @@ -76,7 +84,7 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { |a = 20 |print(foo(a)) |""".stripMargin) - .withExtraFlows(List(FlowSemantic("helpers.py:.foo", List(FlowMapping(0, 0))))) + .withSemantics(DefaultSemantics().plus(List(FlowSemantic("helpers.py:.foo", List(FlowMapping(0, 0)))))) val source = cpg.literal("20").l val sink = cpg.call("print").argument(1).l val flows = sink.reachableByFlows(source).l @@ -89,7 +97,7 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { |a = 20 |print(foo(a)) |""".stripMargin) - .withExtraFlows(List(FlowSemantic("helpers.py:.foo", List(FlowMapping(1, 1))))) + .withSemantics(DefaultSemantics().plus(List(FlowSemantic("helpers.py:.foo", List(FlowMapping(1, 1)))))) val source = cpg.literal("20").l val sink = cpg.call("print").argument(1).l val flows = sink.reachableByFlows(source).l @@ -101,7 +109,7 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { |from helpers import foo |print(foo(20)) |""".stripMargin) - .withExtraFlows(List(FlowSemantic("helpers.py:.foo", List()))) + .withSemantics(DefaultSemantics().after(NilSemantics.where(List("helpers.py:.foo")))) val source = cpg.literal("20").l val sink = cpg.call("print").argument(1).l val flows = sink.reachableByFlows(source).l @@ -113,7 +121,7 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { |from helpers import foo |print(foo(20)) |""".stripMargin) - .withExtraFlows(List(FlowSemantic("helpers.py:.foo", List(FlowMapping(0, 0))))) + .withSemantics(DefaultSemantics().plus(List(FlowSemantic("helpers.py:.foo", List(FlowMapping(0, 0)))))) val source = cpg.literal("20").l val sink = cpg.call("print").argument(1).l val flows = sink.reachableByFlows(source).l @@ -125,7 +133,7 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { |from helpers import foo |print(foo(20)) |""".stripMargin) - .withExtraFlows(List(FlowSemantic("helpers.py:.foo", List(FlowMapping(1, 1))))) + .withSemantics(DefaultSemantics().plus(List(FlowSemantic("helpers.py:.foo", List(FlowMapping(1, 1)))))) val source = cpg.literal("20").l val sink = cpg.call("print").argument(1).l val flows = sink.reachableByFlows(source).l @@ -140,7 +148,7 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { |a = 20 |print(foo(a)) |""".stripMargin) - .withExtraFlows(List(FlowSemantic("Test0.py:.foo", List()))) + .withSemantics(DefaultSemantics().after(NilSemantics.where(List("Test0.py:.foo")))) val source = cpg.literal("20").l val sink = cpg.call("print").argument(1).l val flows = sink.reachableByFlows(source).l @@ -155,68 +163,24 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { |a = 20 |print(foo(a)) |""".stripMargin) - .withExtraFlows(List(FlowSemantic("Test0.py:.foo", List(FlowMapping(0, 0))))) - val source = cpg.literal("20").l - val sink = cpg.call("print").argument(1).l - val flows = sink.reachableByFlows(source).l - flows shouldBe empty - } - - "no flow from aliased literal to method call return value given argument1-only semantics" ignore { - val cpg = code(""" - |def foo(x): - | return x - | - |a = 20 - |print(foo(a)) - |""".stripMargin) - .withExtraFlows(List(FlowSemantic("Test0.py:.foo", List(FlowMapping(1, 1))))) - val source = cpg.literal("20").l - val sink = cpg.call("print").argument(1).l - val flows = sink.reachableByFlows(source).l - flows shouldBe empty - } - - "no flow from literal to method call return value given empty semantics" ignore { - val cpg = code(""" - |def foo(x): - | return x - | - |print(foo(20)) - |""".stripMargin) - .withExtraFlows(List(FlowSemantic("Test0.py:.foo", List()))) - val source = cpg.literal("20").l - val sink = cpg.call("print").argument(1).l - val flows = sink.reachableByFlows(source).l - flows shouldBe empty - } - - "no flow from literal to method call return value given receiver-only semantics" ignore { - val cpg = code(""" - |def foo(x): - | return x - | - |print(foo(20)) - |""".stripMargin) - .withExtraFlows(List(FlowSemantic("Test0.py:.foo", List(FlowMapping(0, 0))))) + .withSemantics(DefaultSemantics().plus(List(FlowSemantic("Test0.py:.foo", List(FlowMapping(0, 0)))))) val source = cpg.literal("20").l val sink = cpg.call("print").argument(1).l val flows = sink.reachableByFlows(source).l flows shouldBe empty } - "no flow from literal to method call return value given argument1-only semantics" ignore { + "don't taint the return value when specifying a named argument" in { val cpg = code(""" - |def foo(x): - | return x - | - |print(foo(20)) + |import foo + |foo.bar(foo.baz(A=1)) |""".stripMargin) - .withExtraFlows(List(FlowSemantic("Test0.py:.foo", List(FlowMapping(1, 1))))) - val source = cpg.literal("20").l - val sink = cpg.call("print").argument(1).l - val flows = sink.reachableByFlows(source).l - flows shouldBe empty + // The taint spec for `baz` here says that its argument "A" only taints itself. This is to make sure + // its return value is not tainted even when we are using `-1` in the spec. + .withSemantics(DefaultSemantics().plus(List(FlowSemantic(".*baz", List(FlowMapping(-1, "A", -1, "A")), true)))) + val one = cpg.literal("1") + val bar = cpg.call("bar").argument + bar.reachableByFlows(one).map(flowToResultPairs) shouldBe empty } "chained call" in { @@ -660,7 +624,7 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { |x = {'x': 10} |print(1, x) |""".stripMargin) - .withExtraFlows(List(FlowSemantic(".*print", List(PassThroughMapping), true))) + .withSemantics(DefaultSemantics().plus(List(FlowSemantic(".*print", List(PassThroughMapping), true)))) def source = cpg.literal def sink = cpg.call("print").argument.argumentIndex(2) @@ -839,21 +803,259 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { ) } + "flow from literal to an external method's named argument using two same-methodFullNamed semantics" in { + val cpg = code(""" + |import bar + |x = 'foobar' + |bar.foo(Baz=x) + |""".stripMargin) + .withSemantics( + DefaultSemantics().plus( + List( + // Equivalent to a single `FlowSemantic` entry with both FlowMappings + FlowSemantic("bar.py:.foo", List(PassThroughMapping)), + FlowSemantic("bar.py:.foo", List(FlowMapping(0, 0))) + ) + ) + ) + + val source = cpg.literal("'foobar'") + val sink = cpg.call("foo").argument.argumentName("Baz") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List( + List(("x = 'foobar'", 3), ("bar.foo(Baz = x)", 4)) + ) + } + +} + +// Showcases that, even though `foo` is defined in the source-code, we are still able to override its semantics. +// Note that using `withSemantics` only updates the query-time semantics. +class InternalMethodCustomSemanticsDataFlowTest + extends PySrc2CpgFixture( + withOssDataflow = true, + semantics = DefaultSemantics().plus(List(FlowSemantic("Test0.py:.foo", List(FlowMapping(1, 1))))) + ) { + + "no flow from literal to method call return value" in { + val cpg = code(""" + |def foo(x): + | return x + | + |print(foo(20)) + |""".stripMargin) + val source = cpg.literal("20") + val sink = cpg.call("print").argument(1) + val flows = sink.reachableByFlows(source) + flows shouldBe empty + } + + "no flow from literal (in an assignment) to method call return value" in { + val cpg = code(""" + |def foo(x): + | return x + | + |a = 20 + |print(foo(a)) + |""".stripMargin) + val source = cpg.literal("20") + val sink = cpg.call("print").argument(1) + val flows = sink.reachableByFlows(source) + flows shouldBe empty + } +} + +class DefaultSemanticsDataFlowTest1 extends PySrc2CpgFixture(withOssDataflow = true, semantics = DefaultSemantics()) { + + "DefaultSemantics cross-taints arguments to external method calls" in { + val cpg = code(""" + |import bar + |a = 1 + |bar.foo(b, Z=a) + |bar.baz(b) + |""".stripMargin) + val source = cpg.literal("1") + val sink = cpg.call("baz") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List( + List(("a = 1", 3), ("bar.foo(b, Z = a)", 4), ("bar.baz(b)", 5)) + ) + } + + "DefaultSemantics taints external method call return values" in { + val cpg = code(""" + |import bar + |y = 1 + |x = bar.foo(y) + |bar.baz(x) + |""".stripMargin) + val source = cpg.literal("1") + val sink = cpg.call("baz") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List( + List(("y = 1", 3), ("bar.foo(y)", 4), ("x = bar.foo(y)", 4), ("bar.baz(x)", 5)) + ) + } + +} + +class NoSemanticsDataFlowTest1 extends PySrc2CpgFixture(withOssDataflow = true, semantics = NoSemantics) { + + "NoSemantics cross-taints arguments to external method calls" in { + val cpg = code(""" + |import bar + |a = 1 + |bar.foo(b, Z=a) + |bar.baz(b) + |""".stripMargin) + val source = cpg.literal("1") + val sink = cpg.call("baz") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List( + List(("a = 1", 3), ("bar.foo(b, Z = a)", 4), ("bar.baz(b)", 5)) + ) + } + + "NoSemantics taints external method call return values" in { + val cpg = code(""" + |import bar + |y = 1 + |x = bar.foo(y) + |bar.baz(x) + |""".stripMargin) + val source = cpg.literal("1") + val sink = cpg.call("baz") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List( + List(("y = 1", 3), ("bar.foo(y)", 4), ("x = bar.foo(y)", 4), ("bar.baz(x)", 5)) + ) + } +} + +class NilSemanticsDataFlowTest1 + extends PySrc2CpgFixture(withOssDataflow = true, semantics = NilSemantics().after(DefaultSemantics())) { + + "NilSemantics does not cross-taint arguments to external method calls" in { + val cpg = code(""" + |import bar + |a = 1 + |bar.foo(b, Z=a) + |bar.baz(b) + |""".stripMargin) + val source = cpg.literal("1") + val sink = cpg.call("baz") + sink.reachableByFlows(source).map(flowToResultPairs) shouldBe empty + } + + "NilSemantics does not taint external method call return values" in { + val cpg = code(""" + |import bar + |y = 1 + |x = bar.foo(y) + |bar.baz(x) + |""".stripMargin) + val source = cpg.literal("1") + val sink = cpg.call("baz") + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List() + } +} + +class NoCrossTaintDataFlowTest1 + extends PySrc2CpgFixture( + withOssDataflow = true, + semantics = NoCrossTaintSemantics.where(_.fullName.contains("bar.py")).after(DefaultSemantics()) + ) { + + "NoCrossTaintSemantics prevents cross-tainting arguments to external method calls" in { + val cpg = code(""" + |import bar + |a = 1 + |bar.foo(b, Z=a) + |bar.baz(b) + |""".stripMargin) + val source = cpg.literal("1") + val sink = cpg.call("baz").argument.argumentIndex(1) + sink.reachableByFlows(source).map(flowToResultPairs) shouldBe empty + } + + "NoCrossTaintSemantics prevents cross-tainting same-call named-arguments to external method calls" in { + val cpg = code(""" + |import bar + |a = 1 + |b = 2 + |bar.foo(X=a, Y=b) + |""".stripMargin) + val source = cpg.literal("1") + val sink = cpg.call("foo").argument.argumentName("Y") + sink.reachableByFlows(source) shouldBe empty + } + + "NoCrossTaintSemantics prevents cross-tainting same-call arguments to external method calls" in { + val cpg = code(""" + |import bar + |a = 1 + |b = 2 + |bar.foo(A=b, a) + |""".stripMargin) + val source = cpg.literal("1") + val sink = cpg.call("foo").argument.argumentName("A") + sink.reachableByFlows(source) shouldBe empty + } + + "NoCrossTaintSemantics taints return values" in { + val cpg = code(""" + |import bar + |a = 1 + |b = 2 + |c = bar.foo(X=a, b) + |print(c) + |""".stripMargin) + val source = cpg.literal.lineNumber(3, 4) + val sink = cpg.call("print").argument + sink.reachableByFlows(source).map(flowToResultPairs).toSetMutable shouldBe Set( + List(("a = 1", 3), ("bar.foo(b, X = a)", 5), ("c = bar.foo(b, X = a)", 5), ("print(c)", 6)), + List(("b = 2", 4), ("bar.foo(b, X = a)", 5), ("c = bar.foo(b, X = a)", 5), ("print(c)", 6)) + ) + } +} + +class NoCrossTaintDataFlowTest2 + extends PySrc2CpgFixture( + withOssDataflow = true, + semantics = NoCrossTaintSemantics.where(_.fullName.contains("foo")).after(DefaultSemantics()) + ) { + + "NoCrossTaintSemantics works for specific external method call" in { + val cpg = code(""" + |import bar + |a = 1 + |bar.foo(a,b) # foo has no-cross-taint semantics, so b is not tainted by a + |bar.baz(a,c) # however, baz has default semantics, so c is tainted by a + |print(b) + |print(c) + |""".stripMargin) + val source = cpg.literal("1") + val sink = cpg.call.name("print").argument.argumentIndex(1) + // Note: it's unfortunate that `(bar.foo(a, b), 4)` still shows up in this flow. + // However, we can check that NoCrossTaintSemantics is doing its job, as otherwise + // we'd also have a `print(b)` sink. + sink.reachableByFlows(source).map(flowToResultPairs).l shouldBe List( + List(("a = 1", 3), ("bar.foo(a, b)", 4), ("bar.baz(a, c)", 5), ("print(c)", 7)) + ) + } + } class RegexDefinedFlowsDataFlowTests extends PySrc2CpgFixture( withOssDataflow = true, - extraFlows = List( - FlowSemantic.from("^path.*\\.sanitizer$", List((0, 0), (1, 1)), regex = true), - FlowSemantic.from("^foo.*\\.sanitizer.*", List((0, 0), (1, 1)), regex = true), - FlowSemantic.from("^foo.*\\.create_sanitizer\\.\\.sanitize", List((0, 0), (1, 1)), regex = true), - FlowSemantic - .from( - "requests.py:.post", - List((0, 0), (1, "url", -1), (2, "body", -1), (1, "url", 1, "url"), (2, "body", 2, "body")) - ), - FlowSemantic.from("cross_taint.py:.go", List((0, 0), (1, 1), (1, "a", 2, "b"))) + semantics = DefaultSemantics().plus( + List( + FlowSemantic.from("^path.*\\.sanitizer$", List((0, 0), (1, 1)), regex = true), + FlowSemantic.from("^foo.*\\.sanitizer.*", List((0, 0), (1, 1)), regex = true), + FlowSemantic.from("^foo.*\\.create_sanitizer\\.\\.sanitize", List((0, 0), (1, 1)), regex = true), + FlowSemantic + .from( + "requests.py:.post", + List((0, 0), (1, "url", -1), (2, "body", -1), (1, "url", 1, "url"), (2, "body", 2, "body")) + ), + FlowSemantic.from("cross_taint.py:.go", List((0, 0), (1, 1), (1, "a", 2, "b"))) + ) ) ) { @@ -969,8 +1171,8 @@ class RegexDefinedFlowsDataFlowTests |print(Foo.func()) |""".stripMargin) "be found" in { - val src = cpg.call.code("Foo.func").l - val snk = cpg.call("print").l + val src = cpg.call.code("Foo.func") + val snk = cpg.call("print") snk.argument.reachableByFlows(src).size shouldBe 1 } } @@ -984,8 +1186,8 @@ class RegexDefinedFlowsDataFlowTests |""".stripMargin) "be found" in { val src = cpg.identifier("Foo").l - val snk = cpg.call("print").l - snk.reachableByFlows(src).size shouldBe 2 + val snk = cpg.call("print").argument(1).l + snk.reachableByFlows(src).size shouldBe 3 } } "Import statement with method ref sample four" in { diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/io/PySrc2CpgHTTPServerTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/io/PySrc2CpgHTTPServerTests.scala new file mode 100644 index 000000000000..732674ecdeb0 --- /dev/null +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/io/PySrc2CpgHTTPServerTests.scala @@ -0,0 +1,82 @@ +package io.joern.pysrc2cpg.io + +import better.files.File +import io.joern.x2cpg.utils.server.FrontendHTTPClient +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader +import io.shiftleft.semanticcpg.language.* +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.collection.parallel.CollectionConverters.RangeIsParallelizable +import scala.util.Failure +import scala.util.Success + +class PySrc2CpgHTTPServerTests extends AnyWordSpec with Matchers with BeforeAndAfterAll { + + private var port: Int = -1 + + private def newProjectUnderTest(index: Option[Int] = None): File = { + val dir = File.newTemporaryDirectory("pysrc2cpgTestsHttpTest") + val file = dir / "main.py" + file.createIfNotExists(createParents = true) + val indexStr = index.map(_.toString).getOrElse("") + file.writeText(s""" + |def main(): + | print($indexStr) + |""".stripMargin) + file.deleteOnExit() + dir.deleteOnExit() + } + + override def beforeAll(): Unit = { + // Start server + port = io.joern.pysrc2cpg.NewMain.startup() + } + + override def afterAll(): Unit = { + // Stop server + io.joern.pysrc2cpg.NewMain.stop() + } + + "Using pysrc2cpg in server mode" should { + "build CPGs correctly (single test)" in { + val cpgOutFile = File.newTemporaryFile("pysrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l should contain("print()") + } + } + + "build CPGs correctly (multi-threaded test)" in { + (0 until 10).par.foreach { index => + val cpgOutFile = File.newTemporaryFile("pysrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest(Some(index)) + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l should contain(s"print($index)") + } + } + } + } + +} diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/ConfigPassTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/ConfigPassTests.scala index 930affce1d65..c8c993f5d995 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/ConfigPassTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/ConfigPassTests.scala @@ -1,7 +1,7 @@ package io.joern.pysrc2cpg.passes -import io.joern.pysrc2cpg.PySrc2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture +import io.shiftleft.semanticcpg.language.* class ConfigPassTests extends PySrc2CpgFixture(withOssDataflow = false) { @@ -12,7 +12,51 @@ class ConfigPassTests extends PySrc2CpgFixture(withOssDataflow = false) { c.content shouldBe "Flask==1.1.2" c.name shouldBe "requirements.txt" } + } + + "Pipfile should be included" in { + val cpg = code( + """ + |[[source]] + |url = "https://pypi.org/simple" + |verify_ssl = true + |name = "pypi" + |""".stripMargin, + "Pipfile" + ) + + val config = cpg.configFile.name("Pipfile").head + config.content should include("verify_ssl = true") + } + + "Pipfile.lock should be included" in { + val cpg = code( + """ + |{ + | "_meta": { + | "hash": { + | "sha256": "293ad83ead15eb7bfef8a768f1853fc4cfa31b32ab85ae6962a2630b57cf569b" + | }, + | "pipfile-spec": 6, + | "requires": { + | "python_full_version": "3.8.18", + | "python_version": "3.8" + | }, + | "sources": [ + | { + | "name": "pypi", + | "url": "https://pypi.org/simple", + | "verify_ssl": true + | } + | ] + | } + |} + |""".stripMargin, + "Pipfile.lock" + ) + val config = cpg.configFile.name("Pipfile.lock").head + config.content should include("\"name\": \"pypi\"") } } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/DynamicTypeHintFullNamePassTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/DynamicTypeHintFullNamePassTests.scala index 284fb6b27408..8db531070350 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/DynamicTypeHintFullNamePassTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/DynamicTypeHintFullNamePassTests.scala @@ -1,7 +1,7 @@ package io.joern.pysrc2cpg.passes -import io.joern.pysrc2cpg.PySrc2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture +import io.shiftleft.semanticcpg.language.* import java.io.File diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/ImportsPassTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/ImportsPassTests.scala index 1a8f338cb498..8cb67c3aee6a 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/ImportsPassTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/ImportsPassTests.scala @@ -1,7 +1,7 @@ package io.joern.pysrc2cpg.passes -import io.joern.pysrc2cpg.PySrc2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture +import io.shiftleft.semanticcpg.language.* class ImportsPassTests extends PySrc2CpgFixture(withOssDataflow = false) { diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/InheritanceFullNamePassTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/InheritanceFullNamePassTests.scala index 71e90af7761e..10611836281c 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/InheritanceFullNamePassTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/InheritanceFullNamePassTests.scala @@ -1,7 +1,7 @@ package io.joern.pysrc2cpg.passes -import io.joern.pysrc2cpg.PySrc2CpgFixture -import io.shiftleft.semanticcpg.language._ +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture +import io.shiftleft.semanticcpg.language.* import java.io.File diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala index 8ae43b3361ba..a564993e2571 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala @@ -1,6 +1,6 @@ package io.joern.pysrc2cpg.passes -import io.joern.pysrc2cpg.PySrc2CpgFixture +import io.joern.pysrc2cpg.testfixtures.PySrc2CpgFixture import io.joern.x2cpg.passes.frontend.XTypeHintCallLinker import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Member} @@ -1051,7 +1051,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "assert the method properties in RedisDB, especially quoted type hints" in { val Some(redisDB) = cpg.typeDecl.nameExact("RedisDB").method.nameExact("").headOption: @unchecked - val List(instanceM, getRedisM, setM) = redisDB.astOut.isMethod.nameExact("instance", "get_redis", "set").l + val List(instanceM, getRedisM, setM) = redisDB.astChildren.isMethod.nameExact("instance", "get_redis", "set").l instanceM.methodReturn.typeFullName shouldBe Seq("db", "redis.py:.RedisDB").mkString(File.separator) getRedisM.methodReturn.typeFullName shouldBe "aioredis.py:.Redis" @@ -1324,12 +1324,9 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { val variables = cpg.moduleVariables .where(_.typeFullName(".*FastAPI.*")) .l - val appIncludeRouterCalls = - variables.invokingCalls - .nameExact("include_router") - .l - val includedRouters = appIncludeRouterCalls.argument.argumentIndexGte(1).moduleVariables.l - val definitionsOfRouters = includedRouters.definitions.whereNot(_.source.isCall.nameExact("import")).l + val appIncludeRouterCalls = variables.invokingCalls.nameExact("include_router") + val includedRouters = appIncludeRouterCalls.argument.argumentIndexGte(1).moduleVariables + val definitionsOfRouters = includedRouters.definitions.whereNot(_.source.isCall.nameExact("import")) val List(adminRouter, normalRouter, itemsRouter) = definitionsOfRouters.map(x => (x.code, x.method.fullName)).sortBy(_._1).l: @unchecked @@ -1641,4 +1638,28 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { } } } + + "external method named `import_table`" should { + val cpg = code(""" + |import boto3 + |client = boto3.client("s3") + |response = client.import_table() + |""".stripMargin) + + "have correct methodFullName for `import_table" in { + cpg.call("import_table").l match { + case List(importTable) => + importTable.methodFullName shouldBe "boto3.py:.client..import_table" + case result => fail(s"Expected single call to import_table, but got $result") + } + } + + "provide meaningful typeFullName for `response`" in { + cpg.assignment.target.isIdentifier.name("response").l match { + case List(response) => + response.typeFullName shouldBe "boto3.py:.client..import_table." + case result => fail(s"Expected single assignment to response, but got $result") + } + } + } } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/Py2CpgTestContext.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/testfixtures/Py2CpgTestContext.scala similarity index 95% rename from joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/Py2CpgTestContext.scala rename to joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/testfixtures/Py2CpgTestContext.scala index 6b3facb4d231..ff1526453a45 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/Py2CpgTestContext.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/testfixtures/Py2CpgTestContext.scala @@ -1,5 +1,6 @@ -package io.joern.pysrc2cpg +package io.joern.pysrc2cpg.testfixtures +import io.joern.pysrc2cpg.Py2Cpg import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.X2Cpg.defaultOverlayCreators import io.shiftleft.codepropertygraph.generated.Cpg diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/PySrc2CpgFixture.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/testfixtures/PySrc2CpgFixture.scala similarity index 60% rename from joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/PySrc2CpgFixture.scala rename to joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/testfixtures/PySrc2CpgFixture.scala index a9b8a22e450a..177f456fde34 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/PySrc2CpgFixture.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/testfixtures/PySrc2CpgFixture.scala @@ -1,26 +1,26 @@ -package io.joern.pysrc2cpg +package io.joern.pysrc2cpg.testfixtures import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.language.Path -import io.joern.dataflowengineoss.layers.dataflows.* -import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} -import io.joern.dataflowengineoss.testfixtures.{SemanticCpgTestFixture, SemanticTestCpg} -import io.joern.x2cpg.X2Cpg -import io.joern.x2cpg.frontendspecific.pysrc2cpg.{ - DynamicTypeHintFullNamePass, - ImportsPass, - PythonImportResolverPass, - PythonInheritanceNamePass, - PythonTypeHintCallLinker, - PythonTypeRecoveryPassGenerator -} +import io.joern.dataflowengineoss.semanticsloader.Semantics +import io.joern.dataflowengineoss.testfixtures.SemanticCpgTestFixture +import io.joern.dataflowengineoss.testfixtures.SemanticTestCpg +import io.joern.pysrc2cpg.Py2CpgOnFileSystem +import io.joern.pysrc2cpg.Py2CpgOnFileSystemConfig +import io.joern.x2cpg.frontendspecific.pysrc2cpg.DynamicTypeHintFullNamePass +import io.joern.x2cpg.frontendspecific.pysrc2cpg.ImportsPass +import io.joern.x2cpg.frontendspecific.pysrc2cpg.PythonImportResolverPass +import io.joern.x2cpg.frontendspecific.pysrc2cpg.PythonInheritanceNamePass +import io.joern.x2cpg.frontendspecific.pysrc2cpg.PythonTypeHintCallLinker +import io.joern.x2cpg.frontendspecific.pysrc2cpg.PythonTypeRecoveryPassGenerator import io.joern.x2cpg.passes.base.AstLinkerPass import io.joern.x2cpg.passes.callgraph.NaiveCallLinker -import io.joern.x2cpg.testfixtures.{Code2CpgFixture, DefaultTestCpg, LanguageFrontend, TestCpg} +import io.joern.x2cpg.testfixtures.Code2CpgFixture +import io.joern.x2cpg.testfixtures.DefaultTestCpg +import io.joern.x2cpg.testfixtures.LanguageFrontend import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language.{ICallResolver, NoResolve} -import io.shiftleft.semanticcpg.layers.LayerCreatorContext +import io.shiftleft.semanticcpg.language.ICallResolver +import io.shiftleft.semanticcpg.language.NoResolve trait PythonFrontend extends LanguageFrontend { override val fileSuffix: String = ".py" @@ -48,7 +48,7 @@ class PySrcTestCpg extends DefaultTestCpg with PythonFrontend with SemanticTestC new PythonTypeHintCallLinker(this).createAndApply() new NaiveCallLinker(this).createAndApply() - // Some of passes above create new methods, so, we + // Some of the passes above create new methods, so, we // need to run the ASTLinkerPass one more time new AstLinkerPass(this).createAndApply() applyOssDataFlow() @@ -58,20 +58,20 @@ class PySrcTestCpg extends DefaultTestCpg with PythonFrontend with SemanticTestC class PySrc2CpgFixture( withOssDataflow: Boolean = false, - extraFlows: List[FlowSemantic] = List.empty, + semantics: Semantics = DefaultSemantics(), withPostProcessing: Boolean = true ) extends Code2CpgFixture(() => new PySrcTestCpg() .withOssDataflow(withOssDataflow) - .withExtraFlows(extraFlows) + .withSemantics(semantics) .withPostProcessingPasses(withPostProcessing) ) - with SemanticCpgTestFixture(extraFlows) { + with SemanticCpgTestFixture(semantics) { implicit val resolver: ICallResolver = NoResolve protected def flowToResultPairs(path: Path): List[(String, Integer)] = - path.resultPairs().collect { case (firstElement: String, secondElement: Option[Integer]) => + path.resultPairs().collect { case (firstElement: String, secondElement) => (firstElement, secondElement.getOrElse(-1)) } } diff --git a/joern-cli/frontends/rubysrc2cpg/.gitignore b/joern-cli/frontends/rubysrc2cpg/.gitignore index 619f14a2c217..a086a83c07a3 100644 --- a/joern-cli/frontends/rubysrc2cpg/.gitignore +++ b/joern-cli/frontends/rubysrc2cpg/.gitignore @@ -2,3 +2,4 @@ gen/ *.tokens type_stubs +src/main/resources/ruby_ast_gen diff --git a/joern-cli/frontends/rubysrc2cpg/README.md b/joern-cli/frontends/rubysrc2cpg/README.md new file mode 100644 index 000000000000..a7b4f98acf12 --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/README.md @@ -0,0 +1,8 @@ +# rubysrc2cpg + +A `parser` Gem based parser for Ruby source code that creates code property graphs according to the specification at https://github.com/ShiftLeftSecurity/codepropertygraph . + +The `parser` Gem is wrapped around a Ruby application [ruby_ast_gen](https://github.com/joernio/ruby_ast_gen) that is +then embedded under `src/main/resources` and executed during runtime using JRuby. + +To update this, set the version under `src/main/resources/application.conf` and run `sbt rubysrc2cpg/astGenDlTask`. \ No newline at end of file diff --git a/joern-cli/frontends/rubysrc2cpg/build.sbt b/joern-cli/frontends/rubysrc2cpg/build.sbt index 9f2372e82bb6..c1c4cfef8e3b 100644 --- a/joern-cli/frontends/rubysrc2cpg/build.sbt +++ b/joern-cli/frontends/rubysrc2cpg/build.sbt @@ -1,4 +1,10 @@ +import better.files import com.typesafe.config.{Config, ConfigFactory} +import versionsort.VersionHelper + +import java.net.URI +import scala.sys.process.stringToProcess +import scala.util.{Failure, Success, Try} name := "rubysrc2cpg" @@ -6,7 +12,7 @@ dependsOn(Projects.dataflowengineoss % "compile->compile;test->test", Projects.x lazy val appProperties = settingKey[Config]("App Properties") appProperties := { - val path = (Compile / resourceDirectory).value / "application.conf" + val path = (Compile / resourceDirectory).value / "application.conf" val applicationConf = ConfigFactory.parseFile(path).resolve() applicationConf } @@ -15,17 +21,57 @@ lazy val joernTypeStubsVersion = settingKey[String]("joern_type_stub version") joernTypeStubsVersion := appProperties.value.getString("rubysrc2cpg.joern_type_stubs_version") libraryDependencies ++= Seq( - "io.shiftleft" %% "codepropertygraph" % Versions.cpg, + "io.shiftleft" %% "codepropertygraph" % Versions.cpg, "org.apache.commons" % "commons-compress" % Versions.commonsCompress, // For unpacking Gems with `--download-dependencies` - "org.scalatest" %% "scalatest" % Versions.scalatest % Test, - "org.antlr" % "antlr4-runtime" % Versions.antlr + "org.jruby" % "jruby-complete" % Versions.jRuby, + "org.scalatest" %% "scalatest" % Versions.scalatest % Test ) -enablePlugins(JavaAppPackaging, LauncherJarPlugin, Antlr4Plugin) +enablePlugins(JavaAppPackaging, LauncherJarPlugin) -Antlr4 / antlr4Version := Versions.antlr -Antlr4 / antlr4GenVisitor := true -Antlr4 / javaSource := (Compile / sourceManaged).value +libraryDependencies ++= Seq( + "io.shiftleft" %% "codepropertygraph" % Versions.cpg, + "org.scalatest" %% "scalatest" % Versions.scalatest % Test +) + +lazy val astGenVersion = settingKey[String]("`ruby_ast_gen` version") +astGenVersion := appProperties.value.getString("rubysrc2cpg.ruby_ast_gen_version") + +lazy val astGenDlUrl = settingKey[String]("astgen download url") +astGenDlUrl := s"https://github.com/joernio/ruby_ast_gen/releases/download/v${astGenVersion.value}/" + +def hasCompatibleAstGenVersion(astGenBaseDir: File, astGenVersion: String): Boolean = { + val versionFile = astGenBaseDir / "lib" / "ruby_ast_gen" / "version.rb" + if (!versionFile.exists) return false + val versionPattern = "VERSION = \"([0-9]+\\.[0-9]+\\.[0-9]+)\"".r + versionPattern.findFirstIn(IO.read(versionFile)) match { + case Some(versionString) => + // Regex group matching doesn't appear to work in SBT + val version = versionString.stripPrefix("VERSION = \"").stripSuffix("\"") + version == astGenVersion + case _ => false + } +} + +lazy val astGenResourceTask = taskKey[Seq[File]](s"Download `ruby_ast_gen` and package this under `resources`") +astGenResourceTask := { + val targetDir = baseDirectory.value / "src" / "main" / "resources" + val gemName = s"ruby_ast_gen_v${astGenVersion.value}.zip" + val compressGemPath = targetDir / gemName + val unpackedGemFullPath = targetDir / gemName.stripSuffix(s"_v${astGenVersion.value}.zip") + if (!hasCompatibleAstGenVersion(unpackedGemFullPath, astGenVersion.value)) { + if (unpackedGemFullPath.exists()) IO.delete(unpackedGemFullPath) + val url = s"${astGenDlUrl.value}$gemName" + sbt.io.Using.urlInputStream(new URI(url).toURL) { inputStream => + sbt.IO.transfer(inputStream, compressGemPath) + } + IO.unzip(compressGemPath, unpackedGemFullPath) + IO.delete(compressGemPath) + } + (unpackedGemFullPath ** "*").get.filter(_.isFile) +} + +Compile / resourceGenerators += astGenResourceTask lazy val joernTypeStubsDlUrl = settingKey[String]("joern_type_stubs download url") joernTypeStubsDlUrl := s"https://github.com/joernio/joern-type-stubs/releases/download/v${joernTypeStubsVersion.value}/" @@ -33,8 +79,8 @@ joernTypeStubsDlUrl := s"https://github.com/joernio/joern-type-stubs/releases/do lazy val joernTypeStubsDlTask = taskKey[Unit]("Download joern-type-stubs") joernTypeStubsDlTask := { val joernTypeStubsDir = baseDirectory.value / "type_stubs" - val fileName = "rubysrc_builtin_types.zip" - val shaFileName = s"$fileName.sha512" + val fileName = "rubysrc_builtin_types.zip" + val shaFileName = s"$fileName.sha512" joernTypeStubsDir.mkdir() @@ -42,7 +88,7 @@ joernTypeStubsDlTask := { DownloadHelper.ensureIsAvailable(s"${joernTypeStubsDlUrl.value}$shaFileName", joernTypeStubsDir / shaFileName) val typeStubsFile = better.files.File(joernTypeStubsDir.getAbsolutePath) / fileName - val checksumFile = better.files.File(joernTypeStubsDir.getAbsolutePath) / shaFileName + val checksumFile = better.files.File(joernTypeStubsDir.getAbsolutePath) / shaFileName val typestubsSha = typeStubsFile.sha512 @@ -64,5 +110,9 @@ joernTypeStubsDlTask := { Compile / compile := ((Compile / compile) dependsOn joernTypeStubsDlTask).value -Universal / packageName := name.value -Universal / topLevelDirectory := None \ No newline at end of file +Universal / packageName := name.value +Universal / topLevelDirectory := None + +/** write the astgen version to the manifest for downstream usage */ +Compile / packageBin / packageOptions += + Package.ManifestAttributes(new java.util.jar.Attributes.Name("Ruby-AstGen-Version") -> astGenVersion.value) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyLexer.g4 b/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyLexer.g4 deleted file mode 100644 index 0aef432c8278..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyLexer.g4 +++ /dev/null @@ -1,897 +0,0 @@ -lexer grammar DeprecatedRubyLexer; - -// -------------------------------------------------------- -// Auxiliary tokens and features -// -------------------------------------------------------- - -@header { - package io.joern.rubysrc2cpg.deprecated.parser; -} - -tokens { - STRING_INTERPOLATION_END, - REGULAR_EXPRESSION_INTERPOLATION_END, - REGULAR_EXPRESSION_START, - QUOTED_NON_EXPANDED_STRING_LITERAL_END, - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_END, - QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - QUOTED_EXPANDED_REGULAR_EXPRESSION_END, - QUOTED_EXPANDED_STRING_LITERAL_END, - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END, - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END, - QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - DELIMITED_STRING_INTERPOLATION_END, - DELIMITED_ARRAY_ITEM_INTERPOLATION_END, - - // The following tokens are created by `RubyLexerPostProcessor` only. - NON_EXPANDED_LITERAL_CHARACTER_SEQUENCE, - EXPANDED_LITERAL_CHARACTER_SEQUENCE -} - -options { - superClass = DeprecatedRubyLexerBase; -} - -// -------------------------------------------------------- -// Keywords -// -------------------------------------------------------- - -LINE__:'__LINE__'; -ENCODING__: '__ENCODING__'; -FILE__: '__FILE__'; -BEGIN_: 'BEGIN'; -END_: 'END'; -ALIAS: 'alias'; -AND: 'and'; -BEGIN: 'begin'; -BREAK: 'break'; -CASE: 'case'; -CLASS: 'class'; -DEF: 'def'; -IS_DEFINED: 'defined?'; -DO: 'do'; -ELSE: 'else'; -ELSIF: 'elsif'; -END: 'end'; -ENSURE: 'ensure'; -FOR: 'for'; -FALSE: 'false'; -IF: 'if'; -IN: 'in'; -MODULE: 'module'; -NEXT: 'next'; -NIL: 'nil'; -NOT: 'not'; -OR: 'or'; -REDO: 'redo'; -RESCUE: 'rescue'; -RETRY: 'retry'; -RETURN: 'return'; -SELF: 'self'; -SUPER: 'super'; -THEN: 'then'; -TRUE: 'true'; -UNDEF: 'undef'; -UNLESS: 'unless'; -UNTIL: 'until'; -WHEN: 'when'; -WHILE: 'while'; -YIELD: 'yield'; - -fragment KEYWORD - : LINE__ - | ENCODING__ - | FILE__ - | BEGIN_ - | END_ - | ALIAS - | AND - | BEGIN - | BREAK - | CASE - | CLASS - | DEF - | IS_DEFINED - | DO - | ELSE - | ELSIF - | END - | ENSURE - | FOR - | FALSE - | IF - | IN - | MODULE - | NEXT - | NIL - | NOT - | OR - | REDO - | RESCUE - | RETRY - | RETURN - | SELF - | SUPER - | THEN - | TRUE - | UNDEF - | UNLESS - | UNTIL - | WHEN - | WHILE - | YIELD - ; - -// -------------------------------------------------------- -// Punctuators -// -------------------------------------------------------- - -LBRACK: '['; -RBRACK: ']'; -LPAREN: '('; -RPAREN: ')'; -LCURLY: '{'; -RCURLY: '}' - { - if (isEndOfInterpolation()) { - popMode(); - setType(popInterpolationEndTokenType()); - } - } -; -COLON: ':'; -COLON2: '::'; -COMMA: ','; -SEMI: ';'; -DOT: '.'; -DOT2: '..'; -DOT3: '...'; -QMARK: '?'; -EQGT: '=>'; -MINUSGT: '->'; - -fragment PUNCTUATOR - : LBRACK - | RBRACK - | LPAREN - | RPAREN - | LCURLY - | RCURLY - | COLON2 - | COMMA - | SEMI - | DOT2 - | DOT3 - | QMARK - | COLON - | EQGT - ; - -// -------------------------------------------------------- -// Operators -// -------------------------------------------------------- - -EMARK: '!'; -EMARKEQ: '!='; -EMARKTILDE: '!~'; -AMP: '&'; -AMP2: '&&'; -AMPDOT: '&.'; -BAR: '|'; -BAR2: '||'; -EQ: '='; -EQ2: '=='; -EQ3: '==='; -CARET: '^'; -LTEQGT: '<=>'; -EQTILDE: '=~'; -GT: '>'; -GTEQ: '>='; -LT: '<'; -LTEQ: '<='; -LT2: '<<'; -GT2: '>>'; -PLUS: '+'; -MINUS: '-'; -STAR: '*'; -STAR2: '**'; -SLASH: '/' - { - if (isStartOfRegexLiteral()) { - setType(REGULAR_EXPRESSION_START); - pushMode(REGULAR_EXPRESSION_MODE); - } - } -; -PERCENT: '%'; -TILDE: '~'; -// These tokens should only occur after a DEF token, as they are solely used to (re)define unary + and - operators. -// This way we won't emit the wrong token in e.g. `x+@y` (which means + between x and @y) -PLUSAT: '+@' {previousNonWsTokenTypeOrEOF() == DEF}?; -MINUSAT: '-@' {previousNonWsTokenTypeOrEOF() == DEF}?; - -ASSIGNMENT_OPERATOR - : ASSIGNMENT_OPERATOR_NAME '=' - ; - -fragment ASSIGNMENT_OPERATOR_NAME - : AMP - | AMP2 - | BAR - | BAR2 - | CARET - | LT2 - | GT2 - | PLUS - | MINUS - | STAR - | STAR2 - | PERCENT - | SLASH - ; - -fragment OPERATOR_METHOD_NAME - : CARET - | AMP - | BAR - | LTEQGT - | EQ2 - | EQ3 - | EQTILDE - | GT - | GTEQ - | LT - | LTEQ - | LT2 - | GT2 - | PLUS - | MINUS - | STAR - | SLASH - | PERCENT - | STAR2 - | TILDE - | PLUSAT - | MINUSAT - | '[]' - | '[]=' - ; - -// -------------------------------------------------------- -// String literals -// -------------------------------------------------------- - -SINGLE_QUOTED_STRING_LITERAL - : '\'' SINGLE_QUOTED_STRING_CHARACTER*? '\'' - ; - -fragment SINGLE_QUOTED_STRING_CHARACTER - : SINGLE_QUOTED_STRING_NON_ESCAPED_CHARACTER - | SINGLE_QUOTED_ESCAPE_SEQUENCE - ; - -fragment SINGLE_QUOTED_STRING_NON_ESCAPED_CHARACTER - : ~['\\] - ; - -fragment SINGLE_QUOTED_ESCAPE_SEQUENCE - : SINGLE_ESCAPE_CHARACTER_SEQUENCE - | SINGLE_QUOTED_STRING_NON_ESCAPED_CHARACTER_SEQUENCE - ; - -fragment SINGLE_ESCAPE_CHARACTER_SEQUENCE - : '\\' SINGLE_QUOTED_STRING_META_CHARACTER - ; - -fragment SINGLE_QUOTED_STRING_META_CHARACTER - : ['\\] - ; - -fragment SINGLE_QUOTED_STRING_NON_ESCAPED_CHARACTER_SEQUENCE - : '\\' SINGLE_QUOTED_STRING_NON_ESCAPED_CHARACTER - ; - -DOUBLE_QUOTED_STRING_START - : '"' - -> pushMode(DOUBLE_QUOTED_STRING_MODE) - ; - -QUOTED_NON_EXPANDED_STRING_LITERAL_START - : '%q' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_NON_EXPANDED_STRING_LITERAL_END); - _input.consume(); - } - -> pushMode(NON_EXPANDED_DELIMITED_STRING_MODE) - ; - -QUOTED_EXPANDED_STRING_LITERAL_START - : '%Q' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_EXPANDED_STRING_LITERAL_END); - _input.consume(); - pushMode(EXPANDED_DELIMITED_STRING_MODE); - } - // This check exists to prevent issuing a QUOTED_EXPANDED_STRING_LITERAL_START - // in obvious arithmetic expressions, such as `20 %(x+1)`. - // Note, however, that we can't have a perfect test at this stage. For instance, - // in `x = 1; x %(2)`, it's clear that's an arithmetic expression, but we - // will still emit a QUOTED_EXPANDED_STRING_LITERAL_START. - | '%(' {!isNumericTokenType(previousTokenTypeOrEOF())}? - { - pushQuotedDelimiter('('); - pushQuotedEndTokenType(QUOTED_EXPANDED_STRING_LITERAL_END); - pushMode(EXPANDED_DELIMITED_STRING_MODE); - } - ; - -QUOTED_EXPANDED_REGULAR_EXPRESSION_START - : '%r' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_EXPANDED_REGULAR_EXPRESSION_END); - _input.consume(); - } - -> pushMode(EXPANDED_DELIMITED_STRING_MODE) - ; - -QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_START - : '%x' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END); - _input.consume(); - } - -> pushMode(EXPANDED_DELIMITED_STRING_MODE) - ; - -// -------------------------------------------------------- -// String (Word) array literals -// -------------------------------------------------------- - -QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_START - : '%w' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_END); - _input.consume(); - } - -> pushMode(NON_EXPANDED_DELIMITED_ARRAY_MODE) - ; - -QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START - : '%W' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END); - _input.consume(); - } - -> pushMode(EXPANDED_DELIMITED_ARRAY_MODE) - ; - -// -------------------------------------------------------- -// Here doc literals -// -------------------------------------------------------- - -HERE_DOC_IDENTIFIER - : '<<' [-~]? [\t]* IDENTIFIER - ; - -HERE_DOC - : '<<' [-~]? [\t]* IDENTIFIER [a-zA-Z_0-9]* NL ( {!heredocEndAhead(getText())}? . )* [a-zA-Z_] [a-zA-Z_0-9]* - ; - -// -------------------------------------------------------- -// Symbol array literals -// -------------------------------------------------------- - -QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_START - : '%i' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_END); - _input.consume(); - } - -> pushMode(NON_EXPANDED_DELIMITED_ARRAY_MODE) - ; - -QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_START - : '%I' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_END); - _input.consume(); - } - -> pushMode(EXPANDED_DELIMITED_ARRAY_MODE) - ; - -// -------------------------------------------------------- -// Data section -// -------------------------------------------------------- - -END_OF_PROGRAM_MARKER - : '__END__' {getCharPositionInLine() == 7}? '\r'? '\n' - -> pushMode(DATA_SECTION_MODE), skip - ; - -// -------------------------------------------------------- -// Numeric literals -// -------------------------------------------------------- - -DECIMAL_INTEGER_LITERAL - : UNPREFIXED_DECIMAL_INTEGER_LITERAL - | PREFIXED_DECIMAL_INTEGER_LITERAL - ; - -BINARY_INTEGER_LITERAL - : '0' [bB] BINARY_DIGIT ('_'? BINARY_DIGIT)* - ; - -OCTAL_INTEGER_LITERAL - : '0' [_oO]? OCTAL_DIGIT ('_'? OCTAL_DIGIT)* - ; - -HEXADECIMAL_INTEGER_LITERAL - : '0' [xX] HEXADECIMAL_DIGIT ('_'? HEXADECIMAL_DIGIT)* - ; - -FLOAT_LITERAL_WITHOUT_EXPONENT - : UNPREFIXED_DECIMAL_INTEGER_LITERAL '.' DIGIT_DECIMAL_PART - ; - -FLOAT_LITERAL_WITH_EXPONENT - : SIGNIFICAND_PART EXPONENT_PART - ; - -fragment UNPREFIXED_DECIMAL_INTEGER_LITERAL - : '0' - | DECIMAL_DIGIT_EXCEPT_0 ('_'? DECIMAL_DIGIT)* - ; - -fragment PREFIXED_DECIMAL_INTEGER_LITERAL - : '0' [dD] DIGIT_DECIMAL_PART - ; - -fragment SIGNIFICAND_PART - : FLOAT_LITERAL_WITHOUT_EXPONENT - | UNPREFIXED_DECIMAL_INTEGER_LITERAL - ; - -fragment EXPONENT_PART - : [eE] ('+' | '-')? DIGIT_DECIMAL_PART - ; - -fragment BINARY_DIGIT - : [0-1] - ; - -fragment OCTAL_DIGIT - : [0-7] - ; - -fragment DIGIT_DECIMAL_PART - : DECIMAL_DIGIT ('_'? DECIMAL_DIGIT)* - ; - -fragment DECIMAL_DIGIT - : [0-9] - ; - -fragment DECIMAL_DIGIT_EXCEPT_0 - : [1-9] - ; - -fragment HEXADECIMAL_DIGIT - : DECIMAL_DIGIT - | [a-f] - | [A-F] - ; - -// -------------------------------------------------------- -// Whitespaces -// -------------------------------------------------------- - -NL: LINE_TERMINATOR+; -WS: WHITESPACE+; - -fragment WHITESPACE - : [\u0009] - | [\u000b] - | [\u000c] - | [\u000d] - | [\u0020] - | LINE_TERMINATOR_ESCAPE_SEQUENCE - ; - -fragment LINE_TERMINATOR_ESCAPE_SEQUENCE - : '\\' LINE_TERMINATOR - ; - -fragment LINE_TERMINATOR - : '\r'? '\n' - ; - -// -------------------------------------------------------- -// Symbols -// -------------------------------------------------------- - -SYMBOL_LITERAL - : ':' (SYMBOL_NAME | (CONSTANT_IDENTIFIER | LOCAL_VARIABLE_IDENTIFIER) '=') - // This check exists to prevent issuing a SYMBOL_LITERAL in whitespace-free associations, e.g. - // in `foo(x:y)`, so that `:y` is not a SYMBOL_LITERAL - // or in `{:x=>1}`, so that `:x=` is not a SYMBOL_LITERAL - {previousTokenTypeOrEOF() != LOCAL_VARIABLE_IDENTIFIER && _input.LA(1) != '>'}? - ; - -fragment SYMBOL_NAME - : INSTANCE_VARIABLE_IDENTIFIER - | GLOBAL_VARIABLE_IDENTIFIER - | CLASS_VARIABLE_IDENTIFIER - | CONSTANT_IDENTIFIER - | LOCAL_VARIABLE_IDENTIFIER - | METHOD_ONLY_IDENTIFIER - | OPERATOR_METHOD_NAME - | KEYWORD - // NOTE: Even though we have PLUSAT and MINUSAT in OPERATOR_METHOD_NAME, the former - // are not emitted unless there's a DEF token before them, cf. their predicate. - // Thus, we need to add them explicitly here in order to recognize standalone SYMBOL_LITERAL tokens as well. - | '+@' - | '-@' - ; - -// -------------------------------------------------------- -// Identifiers -// -------------------------------------------------------- - -LOCAL_VARIABLE_IDENTIFIER - : (LOWERCASE_CHARACTER | '_') IDENTIFIER_CHARACTER* - ; - -GLOBAL_VARIABLE_IDENTIFIER - : '$' IDENTIFIER_START_CHARACTER IDENTIFIER_CHARACTER* - | '$' [0-9]+ - ; - -INSTANCE_VARIABLE_IDENTIFIER - : '@' IDENTIFIER_START_CHARACTER IDENTIFIER_CHARACTER* - ; - -CLASS_VARIABLE_IDENTIFIER - : '@@' IDENTIFIER_START_CHARACTER IDENTIFIER_CHARACTER* - ; - -CONSTANT_IDENTIFIER - : UPPERCASE_CHARACTER IDENTIFIER_CHARACTER* - ; - -fragment METHOD_ONLY_IDENTIFIER - : (CONSTANT_IDENTIFIER | LOCAL_VARIABLE_IDENTIFIER) ('!' | '?') - ; - - -// Similarly to PLUSAT/MINUSAT, this should only occur after a DEF token. -// Otherwise, the assignment `x=nil` would be parsed as (ASSIGNMENT_LIKE_METHOD_IDENTIFIER, NIL) -// instead of the more appropriate (LOCAL_VARIABLE_IDENTIFIER, EQ, NIL). -ASSIGNMENT_LIKE_METHOD_IDENTIFIER - : (CONSTANT_IDENTIFIER | LOCAL_VARIABLE_IDENTIFIER) '=' {previousNonWsTokenTypeOrEOF() == DEF}? - ; - -fragment IDENTIFIER_CHARACTER - : IDENTIFIER_START_CHARACTER - | DECIMAL_DIGIT - ; - -fragment IDENTIFIER_START_CHARACTER - : LOWERCASE_CHARACTER - | UPPERCASE_CHARACTER - | '_' - ; - -fragment LOWERCASE_CHARACTER - : [a-z] - ; - -fragment UPPERCASE_CHARACTER - : [A-Z] - ; - -fragment IDENTIFIER - : LOCAL_VARIABLE_IDENTIFIER - | GLOBAL_VARIABLE_IDENTIFIER - | CLASS_VARIABLE_IDENTIFIER - | INSTANCE_VARIABLE_IDENTIFIER - | CONSTANT_IDENTIFIER - | METHOD_ONLY_IDENTIFIER - | ASSIGNMENT_LIKE_METHOD_IDENTIFIER - ; - -// -------------------------------------------------------- -// Comments (are skipped) -// -------------------------------------------------------- - -SINGLE_LINE_COMMENT - : '#' COMMENT_CONTENT? - -> skip; - -MULTI_LINE_COMMENT - : MULTI_LINE_COMMENT_BEGIN_LINE .*? MULTI_LINE_COMMENT_END_LINE - -> skip; - -fragment COMMENT_CONTENT - : (~[\r\n])+ // Meaning (~LINE_TERMINATOR)+ - ; - -fragment MULTI_LINE_COMMENT_BEGIN_LINE - : '=begin' {getCharPositionInLine() == 6}? REST_OF_BEGIN_END_LINE? LINE_TERMINATOR - ; - -fragment MULTI_LINE_COMMENT_END_LINE - : '=end' {getCharPositionInLine() == 4}? REST_OF_BEGIN_END_LINE? (LINE_TERMINATOR | EOF) - ; - -fragment REST_OF_BEGIN_END_LINE - : WHITESPACE+ COMMENT_CONTENT - ; - -// -------------------------------------------------------- -// Unrecognized characters -// -------------------------------------------------------- - -// Any other character shall still be recognized so that the -// recovery mechanism in `io.joern.rubysrc2cpg.astcreation.AntlrParser` -// also handles them. Otherwise, the lexer would complain, not emit -// and the recovery mechanism would not be able to act. - -// Note: this must be the very last rule in this lexer specification, as -// otherwise this token would take precedence over any token defined after. -UNRECOGNIZED - : . - ; - -// -------------------------------------------------------- -// Double quoted string mode -// -------------------------------------------------------- - -mode DOUBLE_QUOTED_STRING_MODE; - -DOUBLE_QUOTED_STRING_END - : '"' - -> popMode - ; - -DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE - : DOUBLE_QUOTED_STRING_CHARACTER+ - ; - -fragment INTERPOLATED_CHARACTER_SEQUENCE_FRAGMENT - : '#' GLOBAL_VARIABLE_IDENTIFIER - | '#' CLASS_VARIABLE_IDENTIFIER - | '#' INSTANCE_VARIABLE_IDENTIFIER - ; - -INTERPOLATED_CHARACTER_SEQUENCE - : INTERPOLATED_CHARACTER_SEQUENCE_FRAGMENT - ; - -STRING_INTERPOLATION_BEGIN - : '#{' - { - pushInterpolationEndTokenType(STRING_INTERPOLATION_END); - pushMode(DEFAULT_MODE); - } - ; - -fragment DOUBLE_QUOTED_STRING_CHARACTER - : ~["#\\] - | '#' {_input.LA(1) != '$' && _input.LA(1) != '@' && _input.LA(1) != '{'}? - | DOUBLE_ESCAPE_SEQUENCE - ; - -fragment DOUBLE_ESCAPE_SEQUENCE - : SIMPLE_ESCAPE_SEQUENCE - | NON_ESCAPED_SEQUENCE - | LINE_TERMINATOR_ESCAPE_SEQUENCE - | OCTAL_ESCAPE_SEQUENCE - | HEXADECIMAL_ESCAPE_SEQUENCE - | CONTROL_ESCAPE_SEQUENCE - ; - -fragment CONTROL_ESCAPE_SEQUENCE - : '\\' ('C-' | 'c') CONTROL_ESCAPED_CHARACTER - ; - -fragment CONTROL_ESCAPED_CHARACTER - : DOUBLE_ESCAPE_SEQUENCE - | '?' - | ~[\\?] - ; - -fragment OCTAL_ESCAPE_SEQUENCE - : '\\' OCTAL_DIGIT OCTAL_DIGIT? OCTAL_DIGIT? - ; - -fragment HEXADECIMAL_ESCAPE_SEQUENCE - : '\\x' HEXADECIMAL_DIGIT HEXADECIMAL_DIGIT? - ; - -fragment NON_ESCAPED_SEQUENCE - : '\\' NON_ESCAPED_DOUBLE_QUOTED_STRING_CHARACTER - ; - -fragment NON_ESCAPED_DOUBLE_QUOTED_STRING_CHARACTER - : ~[\r\nA-Za-z0-9] - ; - -fragment SIMPLE_ESCAPE_SEQUENCE - : '\\' DOUBLE_ESCAPED_CHARACTER - ; - -fragment DOUBLE_ESCAPED_CHARACTER - : [ntrfvaebsu] - ; - -// -------------------------------------------------------- -// Expanded delimited string mode -// -------------------------------------------------------- - -mode EXPANDED_DELIMITED_STRING_MODE; - -DELIMITED_STRING_INTERPOLATION_BEGIN - : '#{' - { - pushInterpolationEndTokenType(DELIMITED_STRING_INTERPOLATION_END); - pushMode(DEFAULT_MODE); - } - ; - -EXPANDED_VARIABLE_CHARACTER_SEQUENCE - : INTERPOLATED_CHARACTER_SEQUENCE_FRAGMENT - ; - -EXPANDED_LITERAL_CHARACTER - : NON_EXPANDED_LITERAL_ESCAPE_SEQUENCE - | NON_ESCAPED_LITERAL_CHARACTER - { - consumeQuotedCharAndMaybePopMode(_input.LA(-1)); - } - ; - -// -------------------------------------------------------- -// Non-expanded delimited string mode -// -------------------------------------------------------- - -mode NON_EXPANDED_DELIMITED_STRING_MODE; - - -fragment NON_EXPANDED_LITERAL_ESCAPE_SEQUENCE - : '\\' NON_ESCAPED_LITERAL_CHARACTER - ; - -fragment NON_ESCAPED_LITERAL_CHARACTER - : ~[\r\n] - | '\n' {_input.LA(1) != '\r'}? - ; - -NON_EXPANDED_LITERAL_CHARACTER - : NON_EXPANDED_LITERAL_ESCAPE_SEQUENCE - | NON_ESCAPED_LITERAL_CHARACTER - { - consumeQuotedCharAndMaybePopMode(_input.LA(-1)); - } - ; - -// -------------------------------------------------------- -// Expanded delimited array mode -// -------------------------------------------------------- - -mode EXPANDED_DELIMITED_ARRAY_MODE; - -DELIMITED_ARRAY_ITEM_INTERPOLATION_BEGIN - : '#{' - { - pushInterpolationEndTokenType(DELIMITED_ARRAY_ITEM_INTERPOLATION_END); - pushMode(DEFAULT_MODE); - } - ; - -EXPANDED_ARRAY_ITEM_SEPARATOR - : NON_EXPANDED_ARRAY_ITEM_DELIMITER - ; - -EXPANDED_ARRAY_ITEM_CHARACTER - : NON_EXPANDED_LITERAL_ESCAPE_SEQUENCE - | NON_ESCAPED_LITERAL_CHARACTER - { - consumeQuotedCharAndMaybePopMode(_input.LA(-1)); - } - ; - -// -------------------------------------------------------- -// Non-expanded delimited array mode -// -------------------------------------------------------- - -mode NON_EXPANDED_DELIMITED_ARRAY_MODE; - -fragment NON_EXPANDED_ARRAY_ITEM_DELIMITER - : [\u0009] - | [\u000a] - | [\u000b] - | [\u000c] - | [\u000d] - | [\u0020] - | '\\' ('\r'? '\n') - ; - -NON_EXPANDED_ARRAY_ITEM_SEPARATOR - : NON_EXPANDED_ARRAY_ITEM_DELIMITER - ; - -NON_EXPANDED_ARRAY_ITEM_CHARACTER - : NON_EXPANDED_LITERAL_ESCAPE_SEQUENCE - | NON_ESCAPED_LITERAL_CHARACTER - { - consumeQuotedCharAndMaybePopMode(_input.LA(-1)); - } - ; - -// -------------------------------------------------------- -// Regex literal mode -// -------------------------------------------------------- - -mode REGULAR_EXPRESSION_MODE; - -REGULAR_EXPRESSION_END - : '/' REGULAR_EXPRESSION_OPTION* - -> popMode - ; - -REGULAR_EXPRESSION_BODY - : REGULAR_EXPRESSION_CHARACTER+ - ; - -REGULAR_EXPRESSION_INTERPOLATION_BEGIN - : '#{' - { - pushInterpolationEndTokenType(REGULAR_EXPRESSION_INTERPOLATION_END); - pushMode(DEFAULT_MODE); - } - ; - -fragment REGULAR_EXPRESSION_OPTION - : [imxo] - ; - -fragment REGULAR_EXPRESSION_CHARACTER - : ~[/#\\] - | '#' {_input.LA(1) != '$' && _input.LA(1) != '@' && _input.LA(1) != '{'}? - | REGULAR_EXPRESSION_NON_ESCAPED_SEQUENCE - | REGULAR_EXPRESSION_ESCAPE_SEQUENCE - | LINE_TERMINATOR_ESCAPE_SEQUENCE - | INTERPOLATED_CHARACTER_SEQUENCE - ; - -fragment REGULAR_EXPRESSION_NON_ESCAPED_SEQUENCE - : '\\' REGULAR_EXPRESSION_NON_ESCAPED_CHARACTER - ; - -fragment REGULAR_EXPRESSION_NON_ESCAPED_CHARACTER - : ~[\r\n] - | '\n' {_input.LA(1) != '\r'}? - ; - -fragment REGULAR_EXPRESSION_ESCAPE_SEQUENCE - : '\\' '/' - ; - -// -------------------------------------------------------- -// Data section mode -// -------------------------------------------------------- - -mode DATA_SECTION_MODE; - -DATA_SECTION_CONTENT - : .*? EOF - -> popMode, skip - ; diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyParser.g4 b/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyParser.g4 deleted file mode 100644 index 6eb263e22a5e..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyParser.g4 +++ /dev/null @@ -1,758 +0,0 @@ -parser grammar DeprecatedRubyParser; - -@header { - package io.joern.rubysrc2cpg.deprecated.parser; -} - -options { - tokenVocab = DeprecatedRubyLexer; -} - -// -------------------------------------------------------- -// Program -// -------------------------------------------------------- - -program - : compoundStatement EOF - ; - -compoundStatement - : (SEMI | NL)* statements? (SEMI | NL)* - ; - -// -------------------------------------------------------- -// Statements -// -------------------------------------------------------- - -statements - : statement ((SEMI | NL)+ statement)* - ; - -statement - : ALIAS NL? definedMethodNameOrSymbol NL? definedMethodNameOrSymbol # aliasStatement - | UNDEF NL? definedMethodNameOrSymbol (COMMA NL? definedMethodNameOrSymbol)* # undefStatement - | BEGIN_ LCURLY compoundStatement RCURLY # beginStatement - | END_ LCURLY compoundStatement RCURLY # endStatement - | statement mod=(IF | UNLESS | WHILE | UNTIL | RESCUE) NL? statement # modifierStatement - | expressionOrCommand # expressionOrCommandStatement - ; - -// -------------------------------------------------------- -// Expressions -// -------------------------------------------------------- - -expressionOrCommand - : expression # expressionExpressionOrCommand - | (EMARK NL?)? invocationWithoutParentheses # invocationExpressionOrCommand - | NOT NL? expressionOrCommand # notExpressionOrCommand - | expressionOrCommand op=(OR | AND) NL? expressionOrCommand # orAndExpressionOrCommand - ; - -expression - : singleLeftHandSide op=(EQ | ASSIGNMENT_OPERATOR) NL? multipleRightHandSide # singleAssignmentExpression - | multipleLeftHandSide EQ NL? multipleRightHandSide # multipleAssignmentExpression - | primary # primaryExpression - | op=(TILDE | PLUS | EMARK) NL? expression # unaryExpression - | expression STAR2 NL? expression # powerExpression - | MINUS NL? expression # unaryMinusExpression - | expression op=(STAR | SLASH | PERCENT) NL? expression # multiplicativeExpression - | expression op=(PLUS | MINUS) NL? expression # additiveExpression - | expression op=(LT2 | GT2) NL? expression # bitwiseShiftExpression - | expression op=AMP NL? expression # bitwiseAndExpression - | expression op=(BAR | CARET) NL? expression # bitwiseOrExpression - | expression op=(GT | GTEQ | LT | LTEQ) NL? expression # relationalExpression - | expression op=(LTEQGT | EQ2 | EQ3 | EMARKEQ | EQTILDE | EMARKTILDE) NL? expression? # equalityExpression - | expression op=AMP2 NL? expression # operatorAndExpression - | expression op=BAR2 NL? expression # operatorOrExpression - | expression op=(DOT2 | DOT3) NL? expression? # rangeExpression - | expression QMARK NL? expression NL? COLON NL? expression # conditionalOperatorExpression - | IS_DEFINED NL? expression # isDefinedExpression - ; - -primary - : classDefinition # classDefinitionPrimary - | moduleDefinition # moduleDefinitionPrimary - | methodDefinition # methodDefinitionPrimary - | procDefinition # procDefinitionPrimary - | yieldWithOptionalArgument # yieldWithOptionalArgumentPrimary - | ifExpression # ifExpressionPrimary - | unlessExpression # unlessExpressionPrimary - | caseExpression # caseExpressionPrimary - | whileExpression # whileExpressionPrimary - | untilExpression # untilExpressionPrimary - | forExpression # forExpressionPrimary - | RETURN argumentsWithParentheses # returnWithParenthesesPrimary - | jumpExpression # jumpExpressionPrimary - | beginExpression # beginExpressionPrimary - | LPAREN compoundStatement RPAREN # groupingExpressionPrimary - | variableReference # variableReferencePrimary - | COLON2 CONSTANT_IDENTIFIER # simpleScopedConstantReferencePrimary - | primary COLON2 CONSTANT_IDENTIFIER # chainedScopedConstantReferencePrimary - | arrayConstructor # arrayConstructorPrimary - | hashConstructor # hashConstructorPrimary - | literal # literalPrimary - | stringExpression # stringExpressionPrimary - | stringInterpolation # stringInterpolationPrimary - | quotedStringExpression # quotedStringExpressionPrimary - | regexInterpolation # regexInterpolationPrimary - | quotedRegexInterpolation # quotedRegexInterpolationPrimary - | IS_DEFINED LPAREN expressionOrCommand RPAREN # isDefinedPrimary - | SUPER argumentsWithParentheses? block? # superExpressionPrimary - | primary LBRACK indexingArguments? RBRACK # indexingExpressionPrimary - | methodOnlyIdentifier # methodOnlyIdentifierPrimary - | methodIdentifier block # invocationWithBlockOnlyPrimary - | methodIdentifier argumentsWithParentheses block? # invocationWithParenthesesPrimary - | primary NL? (DOT | COLON2| AMPDOT) NL? methodName argumentsWithParentheses? block? # chainedInvocationPrimary - | primary COLON2 methodName block? # chainedInvocationWithoutArgumentsPrimary - ; - -// -------------------------------------------------------- -// Assignments -// -------------------------------------------------------- - -singleLeftHandSide - : variableIdentifier # variableIdentifierOnlySingleLeftHandSide - | primary LBRACK arguments? RBRACK # primaryInsideBracketsSingleLeftHandSide - | primary (DOT | COLON2) (LOCAL_VARIABLE_IDENTIFIER | CONSTANT_IDENTIFIER) # xdotySingleLeftHandSide - | COLON2 CONSTANT_IDENTIFIER # scopedConstantAccessSingleLeftHandSide - ; - -multipleLeftHandSide - : (multipleLeftHandSideItem COMMA NL?)+ (multipleLeftHandSideItem | packingLeftHandSide)? # multipleLeftHandSideAndpackingLeftHandSideMultipleLeftHandSide - | packingLeftHandSide # packingLeftHandSideOnlyMultipleLeftHandSide - | groupedLeftHandSide # groupedLeftHandSideOnlyMultipleLeftHandSide - ; - -multipleLeftHandSideItem - : singleLeftHandSide - | groupedLeftHandSide - ; - -packingLeftHandSide - : STAR singleLeftHandSide - ; - -groupedLeftHandSide - : LPAREN multipleLeftHandSide RPAREN - ; - -multipleRightHandSide - : expressionOrCommands (COMMA NL? splattingArgument)? - | splattingArgument - ; - -expressionOrCommands - : expressionOrCommand (COMMA NL? expressionOrCommand)* - ; - -// -------------------------------------------------------- -// Invocation expressions -// -------------------------------------------------------- - -invocationWithoutParentheses - : chainedCommandWithDoBlock # chainedCommandDoBlockInvocationWithoutParentheses - | command # singleCommandOnlyInvocationWithoutParentheses - | RETURN arguments? # returnArgsInvocationWithoutParentheses - | BREAK arguments # breakArgsInvocationWithoutParentheses - | NEXT arguments # nextArgsInvocationWithoutParentheses - ; - -command - : SUPER argumentsWithoutParentheses # superCommand - | YIELD argumentsWithoutParentheses # yieldCommand - | methodIdentifier argumentsWithoutParentheses # simpleMethodCommand - | primary (DOT | COLON2| AMPDOT) NL? methodName argumentsWithoutParentheses # memberAccessCommand - ; - -chainedCommandWithDoBlock - : commandWithDoBlock ((DOT | COLON2) methodName argumentsWithParentheses?)* - ; - -commandWithDoBlock - : SUPER argumentsWithoutParentheses doBlock # argsAndDoBlockCommandWithDoBlock - | methodIdentifier argumentsWithoutParentheses doBlock # argsAndDoBlockAndMethodIdCommandWithDoBlock - | primary (DOT | COLON2) methodName argumentsWithoutParentheses doBlock # primaryMethodArgsDoBlockCommandWithDoBlock - ; - -argumentsWithoutParentheses - : arguments - ; - -arguments - : argument (COMMA NL? argument)* - ; - -argument - : HERE_DOC_IDENTIFIER # hereDocArgument - | blockArgument # blockArgumentArgument - | splattingArgument # splattingArgumentArgument - | association # associationArgument - | expression # expressionArgument - | command # commandArgument - ; - -blockArgument - : AMP expression - ; - -// -------------------------------------------------------- -// Arguments -// -------------------------------------------------------- - -splattingArgument - : STAR expressionOrCommand - | STAR2 expressionOrCommand - ; - -indexingArguments - : expressions (COMMA NL?)? # expressionsOnlyIndexingArguments - | expressions COMMA NL? splattingArgument # expressionsAndSplattingIndexingArguments - | associations (COMMA NL?)? # associationsOnlyIndexingArguments - | splattingArgument # splattingOnlyIndexingArguments - | command # commandOnlyIndexingArguments - ; - -argumentsWithParentheses - : LPAREN NL? RPAREN # blankArgsArgumentsWithParentheses - | LPAREN NL? arguments (COMMA)? NL? RPAREN # argsOnlyArgumentsWithParentheses - | LPAREN NL? expressions COMMA NL? chainedCommandWithDoBlock NL? RPAREN # expressionsAndChainedCommandWithDoBlockArgumentsWithParentheses - | LPAREN NL? chainedCommandWithDoBlock NL? RPAREN # chainedCommandWithDoBlockOnlyArgumentsWithParentheses - ; - -expressions - : expression (COMMA NL? expression)* - ; - -// -------------------------------------------------------- -// Blocks -// -------------------------------------------------------- - -block - : braceBlock # braceBlockBlock - | doBlock # doBlockBlock - ; - -braceBlock - : LCURLY NL? blockParameter? bodyStatement RCURLY - ; - -doBlock - : DO NL? blockParameter? bodyStatement END - ; - -blockParameter - : BAR blockParameters? BAR - ; - -blockParameters - : singleLeftHandSide - | multipleLeftHandSide - ; - -// -------------------------------------------------------- -// Arrays -// -------------------------------------------------------- - -arrayConstructor - : LBRACK NL? indexingArguments? NL? RBRACK # bracketedArrayConstructor - | QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_START - nonExpandedArrayElements? - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_END # nonExpandedWordArrayConstructor - | QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_START - nonExpandedArrayElements? - QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_END # nonExpandedSymbolArrayConstructor - | QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_START - expandedArrayElements? - QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_END # expandedSymbolArrayConstructor - | QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START - expandedArrayElements? - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END # expandedWordArrayConstructor - ; - - -expandedArrayElements - : EXPANDED_ARRAY_ITEM_SEPARATOR* - expandedArrayElement (EXPANDED_ARRAY_ITEM_SEPARATOR+ expandedArrayElement)* - EXPANDED_ARRAY_ITEM_SEPARATOR* - ; - -expandedArrayElement - : (EXPANDED_ARRAY_ITEM_CHARACTER | delimitedArrayItemInterpolation)+ - ; - -delimitedArrayItemInterpolation - : DELIMITED_ARRAY_ITEM_INTERPOLATION_BEGIN - compoundStatement - DELIMITED_ARRAY_ITEM_INTERPOLATION_END - ; - -nonExpandedArrayElements - : NON_EXPANDED_ARRAY_ITEM_SEPARATOR* - nonExpandedArrayElement (NON_EXPANDED_ARRAY_ITEM_SEPARATOR+ nonExpandedArrayElement)* - NON_EXPANDED_ARRAY_ITEM_SEPARATOR* - ; - -nonExpandedArrayElement - : NON_EXPANDED_ARRAY_ITEM_CHARACTER+ - ; - -// -------------------------------------------------------- -// Hashes -// -------------------------------------------------------- - -hashConstructor - : LCURLY NL? (hashConstructorElements COMMA?)? NL? RCURLY - ; - -hashConstructorElements - : hashConstructorElement (COMMA NL? hashConstructorElement)* - ; - -hashConstructorElement - : association - | STAR2 expression - ; - -associations - : association (COMMA NL? association)* - ; - -association - : (expression | keyword) (EQGT|COLON) (NL? expression)? - ; - -// -------------------------------------------------------- -// Method definitions -// -------------------------------------------------------- - -methodDefinition - : DEF NL? methodNamePart methodParameterPart bodyStatement END - | DEF NL? methodIdentifier methodParameterPart EQ NL? expression - ; - - -procDefinition - : MINUSGT (LPAREN parameters? RPAREN)? block - ; - -methodNamePart - : definedMethodName # simpleMethodNamePart - | singletonObject NL? (DOT | COLON2) NL? definedMethodName # singletonMethodNamePart - ; - -singletonObject - : variableIdentifier - | pseudoVariableIdentifier - | LPAREN expressionOrCommand RPAREN - ; - -definedMethodName - : methodName - | assignmentLikeMethodIdentifier - ; - -assignmentLikeMethodIdentifier - : ASSIGNMENT_LIKE_METHOD_IDENTIFIER - ; - -methodName - : methodIdentifier - | operatorMethodName - | keyword - ; - -methodIdentifier - : LOCAL_VARIABLE_IDENTIFIER - | CONSTANT_IDENTIFIER - | methodOnlyIdentifier - ; - -methodOnlyIdentifier - : (LOCAL_VARIABLE_IDENTIFIER | CONSTANT_IDENTIFIER | keyword) (EMARK | QMARK) - ; - -methodParameterPart - : LPAREN NL? parameters? NL? RPAREN - | parameters? - ; - -parameters - : parameter (COMMA NL? parameter)* - ; - -parameter - : optionalParameter - | mandatoryParameter - | arrayParameter - | hashParameter - | keywordParameter - | procParameter - ; - -mandatoryParameter - : LOCAL_VARIABLE_IDENTIFIER - ; - -optionalParameter - : LOCAL_VARIABLE_IDENTIFIER EQ NL? expression - ; - -arrayParameter - : STAR LOCAL_VARIABLE_IDENTIFIER? - ; - -hashParameter - : STAR2 LOCAL_VARIABLE_IDENTIFIER? - ; - -keywordParameter - : LOCAL_VARIABLE_IDENTIFIER COLON (NL? expression)? - ; - -procParameter - : AMP LOCAL_VARIABLE_IDENTIFIER? - ; - - -// -------------------------------------------------------- -// Conditional expressions -// -------------------------------------------------------- - -ifExpression - : IF NL? expressionOrCommand thenClause elsifClause* elseClause? END - ; - -thenClause - : (SEMI | NL)+ compoundStatement - | (SEMI | NL)? THEN compoundStatement - ; - -elsifClause - : ELSIF NL? expressionOrCommand thenClause - ; - -elseClause - : ELSE compoundStatement - ; - -unlessExpression - : UNLESS NL? expressionOrCommand thenClause elseClause? END - ; - -caseExpression - : CASE NL? expressionOrCommand? (SEMI | NL)* whenClause+ elseClause? END - ; - -whenClause - : WHEN NL? whenArgument thenClause - ; - -whenArgument - : expressions (COMMA splattingArgument)? - | splattingArgument - ; - -// -------------------------------------------------------- -// Iteration expressions -// -------------------------------------------------------- - -whileExpression - : WHILE NL? expressionOrCommand doClause END - ; - -doClause - : (SEMI | NL)+ compoundStatement - | DO compoundStatement - ; - -untilExpression - : UNTIL NL? expressionOrCommand doClause END - ; - -forExpression - : FOR NL? forVariable IN NL? expressionOrCommand doClause END - ; - -forVariable - : singleLeftHandSide - | multipleLeftHandSide - ; - -// -------------------------------------------------------- -// Begin expression -// -------------------------------------------------------- - -beginExpression - : BEGIN bodyStatement END - ; - -bodyStatement - : compoundStatement rescueClause* elseClause? ensureClause? - ; - -rescueClause - : RESCUE exceptionClass? NL? exceptionVariableAssignment? thenClause - ; - -exceptionClass - : expression - | multipleRightHandSide - ; - -exceptionVariableAssignment - : EQGT singleLeftHandSide - ; - -ensureClause - : ENSURE compoundStatement - ; - -// -------------------------------------------------------- -// Class definitions -// -------------------------------------------------------- - -classDefinition - : CLASS NL? classOrModuleReference (LT NL? expressionOrCommand)? bodyStatement END - | CLASS NL? LT2 NL? expressionOrCommand (SEMI | NL)+ bodyStatement END - ; - -classOrModuleReference - : scopedConstantReference - | CONSTANT_IDENTIFIER - ; - -// -------------------------------------------------------- -// Module definitions -// -------------------------------------------------------- - -moduleDefinition - : MODULE NL? classOrModuleReference bodyStatement END - ; - -// -------------------------------------------------------- -// Yield expressions -// -------------------------------------------------------- - -yieldWithOptionalArgument - : YIELD (LPAREN arguments? RPAREN)? - ; - -// -------------------------------------------------------- -// Jump expressions -// -------------------------------------------------------- - -jumpExpression - : BREAK - | NEXT - | REDO - | RETRY - ; - -// -------------------------------------------------------- -// Variable references -// -------------------------------------------------------- - -variableReference - : variableIdentifier # variableIdentifierVariableReference - | pseudoVariableIdentifier # pseudoVariableIdentifierVariableReference - ; - -variableIdentifier - : LOCAL_VARIABLE_IDENTIFIER - | GLOBAL_VARIABLE_IDENTIFIER - | INSTANCE_VARIABLE_IDENTIFIER - | CLASS_VARIABLE_IDENTIFIER - | CONSTANT_IDENTIFIER - ; - -pseudoVariableIdentifier - : NIL # nilPseudoVariableIdentifier - | TRUE # truePseudoVariableIdentifier - | FALSE # falsePseudoVariableIdentifier - | SELF # selfPseudoVariableIdentifier - | FILE__ # filePseudoVariableIdentifier - | LINE__ # linePseudoVariableIdentifier - | ENCODING__ # encodingPseudoVariableIdentifier - ; - -scopedConstantReference - : COLON2 CONSTANT_IDENTIFIER - | primary COLON2 CONSTANT_IDENTIFIER - ; - -// -------------------------------------------------------- -// Literals -// -------------------------------------------------------- - -literal - : HERE_DOC # hereDocLiteral - | numericLiteral # numericLiteralLiteral - | symbol # symbolLiteral - | REGULAR_EXPRESSION_START REGULAR_EXPRESSION_BODY? REGULAR_EXPRESSION_END # regularExpressionLiteral - ; - -symbol - : SYMBOL_LITERAL - | COLON stringExpression - ; - -// -------------------------------------------------------- -// Strings -// -------------------------------------------------------- - -stringExpression - : simpleString # simpleStringExpression - | stringInterpolation # interpolatedStringExpression - | stringExpression stringExpression+ # concatenatedStringExpression - ; - -quotedStringExpression - : QUOTED_NON_EXPANDED_STRING_LITERAL_START - NON_EXPANDED_LITERAL_CHARACTER_SEQUENCE? - QUOTED_NON_EXPANDED_STRING_LITERAL_END # nonExpandedQuotedStringLiteral - | QUOTED_EXPANDED_STRING_LITERAL_START - (EXPANDED_LITERAL_CHARACTER_SEQUENCE | delimitedStringInterpolation)* - QUOTED_EXPANDED_STRING_LITERAL_END # expandedQuotedStringLiteral - | QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_START - (EXPANDED_LITERAL_CHARACTER_SEQUENCE | delimitedStringInterpolation)* - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END # expandedExternalCommandLiteral - ; - -simpleString - : SINGLE_QUOTED_STRING_LITERAL # singleQuotedStringLiteral - | DOUBLE_QUOTED_STRING_START DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE? DOUBLE_QUOTED_STRING_END # doubleQuotedStringLiteral - ; - -delimitedStringInterpolation - : DELIMITED_STRING_INTERPOLATION_BEGIN - compoundStatement - DELIMITED_STRING_INTERPOLATION_END - ; - -stringInterpolation - : DOUBLE_QUOTED_STRING_START - (DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE | interpolatedStringSequence)+ - DOUBLE_QUOTED_STRING_END - ; - -interpolatedStringSequence - : STRING_INTERPOLATION_BEGIN compoundStatement STRING_INTERPOLATION_END - ; - -// -------------------------------------------------------- -// Regex interpolation -// -------------------------------------------------------- - -regexInterpolation - : REGULAR_EXPRESSION_START - (REGULAR_EXPRESSION_BODY | interpolatedRegexSequence)+ - REGULAR_EXPRESSION_END - ; - -interpolatedRegexSequence - : REGULAR_EXPRESSION_INTERPOLATION_BEGIN compoundStatement REGULAR_EXPRESSION_INTERPOLATION_END - ; - -quotedRegexInterpolation - : QUOTED_EXPANDED_REGULAR_EXPRESSION_START - (EXPANDED_LITERAL_CHARACTER_SEQUENCE | delimitedStringInterpolation)* - QUOTED_EXPANDED_REGULAR_EXPRESSION_END - ; - - -// -------------------------------------------------------- -// Numerics -// -------------------------------------------------------- - -numericLiteral - : (PLUS | MINUS)? unsignedNumericLiteral - ; - -unsignedNumericLiteral - : DECIMAL_INTEGER_LITERAL - | BINARY_INTEGER_LITERAL - | OCTAL_INTEGER_LITERAL - | HEXADECIMAL_INTEGER_LITERAL - | FLOAT_LITERAL_WITHOUT_EXPONENT - | FLOAT_LITERAL_WITH_EXPONENT - ; - -// -------------------------------------------------------- -// Helpers -// -------------------------------------------------------- - -definedMethodNameOrSymbol - : definedMethodName - | symbol - ; - -keyword - : LINE__ - | ENCODING__ - | FILE__ - | BEGIN_ - | END_ - | ALIAS - | AND - | BEGIN - | BREAK - | CASE - | CLASS - | DEF - | IS_DEFINED - | DO - | ELSE - | ELSIF - | END - | ENSURE - | FOR - | FALSE - | IF - | IN - | MODULE - | NEXT - | NIL - | NOT - | OR - | REDO - | RESCUE - | RETRY - | RETURN - | SELF - | SUPER - | THEN - | TRUE - | UNDEF - | UNLESS - | UNTIL - | WHEN - | WHILE - | YIELD - ; - -operatorMethodName - : CARET - | AMP - | BAR - | LTEQGT - | EQ2 - | EQ3 - | EQTILDE - | GT - | GTEQ - | LT - | LTEQ - | LT2 - | GT2 - | PLUS - | MINUS - | STAR - | SLASH - | PERCENT - | STAR2 - | TILDE - | PLUSAT - | MINUSAT - | LBRACK RBRACK - | LBRACK RBRACK EQ - ; diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/parser/RubyLexer.g4 b/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/parser/RubyLexer.g4 deleted file mode 100644 index 321ad6097bfe..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/parser/RubyLexer.g4 +++ /dev/null @@ -1,923 +0,0 @@ -lexer grammar RubyLexer; - -// -------------------------------------------------------- -// Auxiliary tokens and features -// -------------------------------------------------------- - -@header { - package io.joern.rubysrc2cpg.parser; -} - -tokens { - STRING_INTERPOLATION_END, - REGULAR_EXPRESSION_INTERPOLATION_END, - REGULAR_EXPRESSION_START, - QUOTED_NON_EXPANDED_STRING_LITERAL_END, - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_END, - QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - QUOTED_EXPANDED_REGULAR_EXPRESSION_END, - QUOTED_EXPANDED_STRING_LITERAL_END, - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END, - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END, - QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - DELIMITED_STRING_INTERPOLATION_END, - DELIMITED_ARRAY_ITEM_INTERPOLATION_END, - - // The following tokens are created by `RubyLexerPostProcessor` only. - NON_EXPANDED_LITERAL_CHARACTER_SEQUENCE, - EXPANDED_LITERAL_CHARACTER_SEQUENCE -} - -options { - superClass = RubyLexerBase; -} - -// -------------------------------------------------------- -// Keywords -// -------------------------------------------------------- - -LINE__:'__LINE__'; -ENCODING__: '__ENCODING__'; -FILE__: '__FILE__'; -BEGIN_: 'BEGIN'; -END_: 'END'; -ALIAS: 'alias'; -AND: 'and'; -BEGIN: 'begin'; -BREAK: 'break'; -CASE: 'case'; -CLASS: 'class'; -DEF: 'def'; -IS_DEFINED: 'defined?'; -DO: 'do'; -ELSE: 'else'; -ELSIF: 'elsif'; -END: 'end'; -ENSURE: 'ensure'; -FOR: 'for'; -FALSE: 'false'; -IF: 'if'; -IN: 'in'; -MODULE: 'module'; -NEXT: 'next'; -NIL: 'nil'; -NOT: 'not'; -OR: 'or'; -REDO: 'redo'; -RESCUE: 'rescue'; -RETRY: 'retry'; -RETURN: 'return'; -SELF: 'self'; -SUPER: 'super'; -THEN: 'then'; -TRUE: 'true'; -UNDEF: 'undef'; -UNLESS: 'unless'; -UNTIL: 'until'; -WHEN: 'when'; -WHILE: 'while'; -YIELD: 'yield'; - -fragment KEYWORD - : LINE__ - | ENCODING__ - | FILE__ - | BEGIN_ - | END_ - | ALIAS - | AND - | BEGIN - | BREAK - | CASE - | CLASS - | DEF - | IS_DEFINED - | DO - | ELSE - | ELSIF - | END - | ENSURE - | FOR - | FALSE - | IF - | IN - | MODULE - | NEXT - | NIL - | NOT - | OR - | REDO - | RESCUE - | RETRY - | RETURN - | SELF - | SUPER - | THEN - | TRUE - | UNDEF - | UNLESS - | UNTIL - | WHEN - | WHILE - | YIELD - ; - -// -------------------------------------------------------- -// Punctuators -// -------------------------------------------------------- - -LBRACK: '['; -RBRACK: ']'; -LPAREN: '('; -RPAREN: ')'; -LCURLY: '{'; -RCURLY: '}' - { - if (isEndOfInterpolation()) { - popMode(); - setType(popInterpolationEndTokenType()); - } - } -; -COLON: ':'; -COLON2: '::'; -COMMA: ','; -SEMI: ';'; -DOT: '.'; -DOT2: '..'; -DOT3: '...'; -QMARK: '?'; -EQGT: '=>'; -MINUSGT: '->'; - -fragment PUNCTUATOR - : LBRACK - | RBRACK - | LPAREN - | RPAREN - | LCURLY - | RCURLY - | COLON2 - | COMMA - | SEMI - | DOT2 - | DOT3 - | QMARK - | COLON - | EQGT - ; - -// -------------------------------------------------------- -// Operators -// -------------------------------------------------------- - -EMARK: '!'; -EMARKEQ: '!='; -EMARKTILDE: '!~'; -AMP: '&'; -AMP2: '&&'; -AMPDOT: '&.'; -BAR: '|'; -BAR2: '||'; -EQ: '='; -EQ2: '=='; -EQ3: '==='; -CARET: '^'; -LTEQGT: '<=>'; -EQTILDE: '=~'; -GT: '>'; -GTEQ: '>='; -LT: '<'; -LTEQ: '<='; -LT2: '<<'; -GT2: '>>'; -PLUS: '+'; -MINUS: '-'; -STAR: '*'; -STAR2: '**'; -SLASH: '/' - { - if (isStartOfRegexLiteral()) { - setType(REGULAR_EXPRESSION_START); - pushMode(REGULAR_EXPRESSION_MODE); - } - } -; -PERCENT: '%'; -TILDE: '~'; -// These tokens should only occur after a DEF token, as they are solely used to (re)define unary + and - operators. -// This way we won't emit the wrong token in e.g. `x+@y` (which means + between x and @y) -PLUSAT: '+@' {previousNonWsTokenTypeOrEOF() == DEF}?; -MINUSAT: '-@' {previousNonWsTokenTypeOrEOF() == DEF}?; - -ASSIGNMENT_OPERATOR - : ASSIGNMENT_OPERATOR_NAME '=' - ; - -fragment ASSIGNMENT_OPERATOR_NAME - : AMP - | AMP2 - | BAR - | BAR2 - | CARET - | LT2 - | GT2 - | PLUS - | MINUS - | STAR - | STAR2 - | PERCENT - | SLASH - ; - -fragment OPERATOR_METHOD_NAME - : CARET - | AMP - | BAR - | LTEQGT - | EQ2 - | EQ3 - | EQTILDE - | GT - | GTEQ - | LT - | LTEQ - | LT2 - | GT2 - | PLUS - | MINUS - | STAR - | SLASH - | PERCENT - | STAR2 - | TILDE - | PLUSAT - | MINUSAT - | '[]' - | '[]=' - ; - -// -------------------------------------------------------- -// String literals -// -------------------------------------------------------- - -SINGLE_QUOTED_STRING_LITERAL - : '\'' SINGLE_QUOTED_STRING_CHARACTER*? '\'' - ; - -fragment SINGLE_QUOTED_STRING_CHARACTER - : SINGLE_QUOTED_STRING_NON_ESCAPED_CHARACTER - | SINGLE_QUOTED_ESCAPE_SEQUENCE - ; - -fragment SINGLE_QUOTED_STRING_NON_ESCAPED_CHARACTER - : ~['\\] - ; - -fragment SINGLE_QUOTED_ESCAPE_SEQUENCE - : SINGLE_ESCAPE_CHARACTER_SEQUENCE - | SINGLE_QUOTED_STRING_NON_ESCAPED_CHARACTER_SEQUENCE - ; - -fragment SINGLE_ESCAPE_CHARACTER_SEQUENCE - : '\\' SINGLE_QUOTED_STRING_META_CHARACTER - ; - -fragment SINGLE_QUOTED_STRING_META_CHARACTER - : ['\\] - ; - -fragment SINGLE_QUOTED_STRING_NON_ESCAPED_CHARACTER_SEQUENCE - : '\\' SINGLE_QUOTED_STRING_NON_ESCAPED_CHARACTER - ; - -DOUBLE_QUOTED_STRING_START - : '"' - -> pushMode(DOUBLE_QUOTED_STRING_MODE) - ; - -QUOTED_NON_EXPANDED_STRING_LITERAL_START - : '%q' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_NON_EXPANDED_STRING_LITERAL_END); - _input.consume(); - } - -> pushMode(NON_EXPANDED_DELIMITED_STRING_MODE) - ; - -QUOTED_EXPANDED_STRING_LITERAL_START - : '%Q' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_EXPANDED_STRING_LITERAL_END); - _input.consume(); - pushMode(EXPANDED_DELIMITED_STRING_MODE); - } - // This check exists to prevent issuing a QUOTED_EXPANDED_STRING_LITERAL_START - // in obvious arithmetic expressions, such as `20 %(x+1)`. - // Note, however, that we can't have a perfect test at this stage. For instance, - // in `x = 1; x %(2)`, it's clear that's an arithmetic expression, but we - // will still emit a QUOTED_EXPANDED_STRING_LITERAL_START. - | '%(' {!isNumericTokenType(previousTokenTypeOrEOF())}? - { - pushQuotedDelimiter('('); - pushQuotedEndTokenType(QUOTED_EXPANDED_STRING_LITERAL_END); - pushMode(EXPANDED_DELIMITED_STRING_MODE); - } - ; - -QUOTED_EXPANDED_REGULAR_EXPRESSION_START - : '%r' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_EXPANDED_REGULAR_EXPRESSION_END); - _input.consume(); - } - -> pushMode(EXPANDED_DELIMITED_STRING_MODE) - ; - -QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_START - : '%x' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END); - _input.consume(); - } - -> pushMode(EXPANDED_DELIMITED_STRING_MODE) - ; - -// -------------------------------------------------------- -// String (Word) array literals -// -------------------------------------------------------- - -QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_START - : '%w' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_END); - _input.consume(); - } - -> pushMode(NON_EXPANDED_DELIMITED_ARRAY_MODE) - ; - -QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START - : '%W' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END); - _input.consume(); - } - -> pushMode(EXPANDED_DELIMITED_ARRAY_MODE) - ; - -// -------------------------------------------------------- -// Here doc literals -// -------------------------------------------------------- - -HERE_DOC_IDENTIFIER - : '<<' [-~]? [\t]* IDENTIFIER - ; - -HERE_DOC - : '<<' [-~]? [\t]* IDENTIFIER [a-zA-Z_0-9]* NL WS* ( {!heredocEndAhead(getText())}? . )* NL? WS* [a-zA-Z_] [a-zA-Z_0-9]* - ; - -// -------------------------------------------------------- -// Symbol array literals -// -------------------------------------------------------- - -QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_START - : '%i' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_END); - _input.consume(); - } - -> pushMode(NON_EXPANDED_DELIMITED_ARRAY_MODE) - ; - -QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_START - : '%I' {!Character.isAlphabetic(_input.LA(1))}? - { - pushQuotedDelimiter(_input.LA(1)); - pushQuotedEndTokenType(QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_END); - _input.consume(); - } - -> pushMode(EXPANDED_DELIMITED_ARRAY_MODE) - ; - -// -------------------------------------------------------- -// Data section -// -------------------------------------------------------- - -END_OF_PROGRAM_MARKER - : '__END__' {getCharPositionInLine() == 7}? '\r'? '\n' - -> pushMode(DATA_SECTION_MODE), skip - ; - -// -------------------------------------------------------- -// Numeric literals -// -------------------------------------------------------- - -DECIMAL_INTEGER_LITERAL - : UNPREFIXED_DECIMAL_INTEGER_LITERAL - | PREFIXED_DECIMAL_INTEGER_LITERAL - ; - -BINARY_INTEGER_LITERAL - : '0' [bB] BINARY_DIGIT ('_'? BINARY_DIGIT)* - ; - -OCTAL_INTEGER_LITERAL - : '0' [_oO]? OCTAL_DIGIT ('_'? OCTAL_DIGIT)* - ; - -HEXADECIMAL_INTEGER_LITERAL - : '0' [xX] HEXADECIMAL_DIGIT ('_'? HEXADECIMAL_DIGIT)* - ; - -FLOAT_LITERAL_WITHOUT_EXPONENT - : UNPREFIXED_DECIMAL_INTEGER_LITERAL '.' DIGIT_DECIMAL_PART - ; - -FLOAT_LITERAL_WITH_EXPONENT - : SIGNIFICAND_PART EXPONENT_PART - ; - -fragment UNPREFIXED_DECIMAL_INTEGER_LITERAL - : '0' - | DECIMAL_DIGIT_EXCEPT_0 ('_'? DECIMAL_DIGIT)* - ; - -fragment PREFIXED_DECIMAL_INTEGER_LITERAL - : '0' [dD] DIGIT_DECIMAL_PART - ; - -fragment SIGNIFICAND_PART - : FLOAT_LITERAL_WITHOUT_EXPONENT - | UNPREFIXED_DECIMAL_INTEGER_LITERAL - ; - -fragment EXPONENT_PART - : [eE] ('+' | '-')? DIGIT_DECIMAL_PART - ; - -fragment BINARY_DIGIT - : [0-1] - ; - -fragment OCTAL_DIGIT - : [0-7] - ; - -fragment DIGIT_DECIMAL_PART - : DECIMAL_DIGIT ('_'? DECIMAL_DIGIT)* - ; - -fragment DECIMAL_DIGIT - : [0-9] - ; - -fragment DECIMAL_DIGIT_EXCEPT_0 - : [1-9] - ; - -fragment HEXADECIMAL_DIGIT - : DECIMAL_DIGIT - | [a-f] - | [A-F] - ; - -// -------------------------------------------------------- -// Whitespaces -// -------------------------------------------------------- - -NL: LINE_TERMINATOR+; -WS: WHITESPACE+ -> channel(HIDDEN); - -fragment WHITESPACE - : [\u0009] - | [\u000b] - | [\u000c] - | [\u000d] - | [\u0020] - | LINE_TERMINATOR_ESCAPE_SEQUENCE - ; - -fragment LINE_TERMINATOR_ESCAPE_SEQUENCE - : '\\' LINE_TERMINATOR - ; - -fragment LINE_TERMINATOR - : '\r'? '\n' - ; - -// -------------------------------------------------------- -// Symbols -// -------------------------------------------------------- - -SYMBOL_LITERAL - : ':' (SYMBOL_NAME | (CONSTANT_IDENTIFIER | LOCAL_VARIABLE_IDENTIFIER) '=') - // This check exists to prevent issuing a SYMBOL_LITERAL in whitespace-free associations, e.g. - // in `foo(x:y)`, so that `:y` is not a SYMBOL_LITERAL - // or in `{:x=>1}`, so that `:x=` is not a SYMBOL_LITERAL - {previousTokenTypeOrEOF() != LOCAL_VARIABLE_IDENTIFIER && _input.LA(1) != '>'}? - ; - -fragment SYMBOL_NAME - : INSTANCE_VARIABLE_IDENTIFIER - | GLOBAL_VARIABLE_IDENTIFIER - | CLASS_VARIABLE_IDENTIFIER - | CONSTANT_IDENTIFIER - | LOCAL_VARIABLE_IDENTIFIER - | METHOD_ONLY_IDENTIFIER - | OPERATOR_METHOD_NAME - | KEYWORD - // NOTE: Even though we have PLUSAT and MINUSAT in OPERATOR_METHOD_NAME, the former - // are not emitted unless there's a DEF token before them, cf. their predicate. - // Thus, we need to add them explicitly here in order to recognize standalone SYMBOL_LITERAL tokens as well. - | '+@' - | '-@' - ; - -// -------------------------------------------------------- -// Identifiers -// -------------------------------------------------------- - -LOCAL_VARIABLE_IDENTIFIER - : (LOWERCASE_CHARACTER | '_') IDENTIFIER_CHARACTER* - ; - -GLOBAL_VARIABLE_IDENTIFIER - : '$' IDENTIFIER_START_CHARACTER IDENTIFIER_CHARACTER* - | '$' [0-9]+ - | '$!' - | '$@' - | '$~' - | '$&' - | '$`' - | '$\'' - | '$+' - | '$=' - | '$/' - | '$\\' - | '$,' - | '$;' - | '$.' - | '$:' - | '$<' - | '$>' - | '$_' - | '$0' - | '$*' - | '$$' - | '$?' - | '$-a' - | '$-i' - | '$-l' - | '$-p' - ; - -INSTANCE_VARIABLE_IDENTIFIER - : '@' IDENTIFIER_START_CHARACTER IDENTIFIER_CHARACTER* - ; - -CLASS_VARIABLE_IDENTIFIER - : '@@' IDENTIFIER_START_CHARACTER IDENTIFIER_CHARACTER* - ; - -CONSTANT_IDENTIFIER - : UPPERCASE_CHARACTER IDENTIFIER_CHARACTER* - ; - -fragment METHOD_ONLY_IDENTIFIER - : (CONSTANT_IDENTIFIER | LOCAL_VARIABLE_IDENTIFIER) (EMARK | QMARK | EQ) - ; - - -// Similarly to PLUSAT/MINUSAT, this should only occur after a DEF token. -// Otherwise, the assignment `x=nil` would be parsed as (ASSIGNMENT_LIKE_METHOD_IDENTIFIER, NIL) -// instead of the more appropriate (LOCAL_VARIABLE_IDENTIFIER, EQ, NIL). -ASSIGNMENT_LIKE_METHOD_IDENTIFIER - : (CONSTANT_IDENTIFIER | LOCAL_VARIABLE_IDENTIFIER) '=' {previousNonWsTokenTypeOrEOF() == DEF}? - ; - -fragment IDENTIFIER_CHARACTER - : IDENTIFIER_START_CHARACTER - | DECIMAL_DIGIT - | '_' - ; - -fragment IDENTIFIER_START_CHARACTER - : LOWERCASE_CHARACTER - | UPPERCASE_CHARACTER - | '_' - ; - -fragment LOWERCASE_CHARACTER - : [a-z] - ; - -fragment UPPERCASE_CHARACTER - : [A-Z] - ; - -fragment IDENTIFIER - : LOCAL_VARIABLE_IDENTIFIER - | GLOBAL_VARIABLE_IDENTIFIER - | CLASS_VARIABLE_IDENTIFIER - | INSTANCE_VARIABLE_IDENTIFIER - | CONSTANT_IDENTIFIER - | METHOD_ONLY_IDENTIFIER - | ASSIGNMENT_LIKE_METHOD_IDENTIFIER - ; - -// -------------------------------------------------------- -// Comments (are skipped) -// -------------------------------------------------------- - -SINGLE_LINE_COMMENT - : '#' COMMENT_CONTENT? - -> skip; - -MULTI_LINE_COMMENT - : MULTI_LINE_COMMENT_BEGIN_LINE .*? MULTI_LINE_COMMENT_END_LINE - -> skip; - -fragment COMMENT_CONTENT - : (~[\r\n])+ // Meaning (~LINE_TERMINATOR)+ - ; - -fragment MULTI_LINE_COMMENT_BEGIN_LINE - : '=begin' {getCharPositionInLine() == 6}? REST_OF_BEGIN_END_LINE? LINE_TERMINATOR - ; - -fragment MULTI_LINE_COMMENT_END_LINE - : '=end' {getCharPositionInLine() == 4}? REST_OF_BEGIN_END_LINE? (LINE_TERMINATOR | EOF) - ; - -fragment REST_OF_BEGIN_END_LINE - : WHITESPACE+ COMMENT_CONTENT - ; - -// -------------------------------------------------------- -// Unrecognized characters -// -------------------------------------------------------- - -// Any other character shall still be recognized so that the -// recovery mechanism in `io.joern.rubysrc2cpg.astcreation.AntlrParser` -// also handles them. Otherwise, the lexer would complain, not emit -// and the recovery mechanism would not be able to act. - -// Note: this must be the very last rule in this lexer specification, as -// otherwise this token would take precedence over any token defined after. -UNRECOGNIZED - : . - ; - -// -------------------------------------------------------- -// Double quoted string mode -// -------------------------------------------------------- - -mode DOUBLE_QUOTED_STRING_MODE; - -DOUBLE_QUOTED_STRING_END - : '"' - -> popMode - ; - -DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE - : DOUBLE_QUOTED_STRING_CHARACTER+ - ; - -fragment INTERPOLATED_CHARACTER_SEQUENCE_FRAGMENT - : '#' GLOBAL_VARIABLE_IDENTIFIER - | '#' CLASS_VARIABLE_IDENTIFIER - | '#' INSTANCE_VARIABLE_IDENTIFIER - ; - -INTERPOLATED_CHARACTER_SEQUENCE - : INTERPOLATED_CHARACTER_SEQUENCE_FRAGMENT - ; - -STRING_INTERPOLATION_BEGIN - : '#{' - { - pushInterpolationEndTokenType(STRING_INTERPOLATION_END); - pushMode(DEFAULT_MODE); - } - ; - -fragment DOUBLE_QUOTED_STRING_CHARACTER - : ~["#\\] - | '#' {_input.LA(1) != '$' && _input.LA(1) != '@' && _input.LA(1) != '{'}? - | DOUBLE_ESCAPE_SEQUENCE - ; - -fragment DOUBLE_ESCAPE_SEQUENCE - : SIMPLE_ESCAPE_SEQUENCE - | NON_ESCAPED_SEQUENCE - | LINE_TERMINATOR_ESCAPE_SEQUENCE - | OCTAL_ESCAPE_SEQUENCE - | HEXADECIMAL_ESCAPE_SEQUENCE - | CONTROL_ESCAPE_SEQUENCE - ; - -fragment CONTROL_ESCAPE_SEQUENCE - : '\\' ('C-' | 'c') CONTROL_ESCAPED_CHARACTER - ; - -fragment CONTROL_ESCAPED_CHARACTER - : DOUBLE_ESCAPE_SEQUENCE - | '?' - | ~[\\?] - ; - -fragment OCTAL_ESCAPE_SEQUENCE - : '\\' OCTAL_DIGIT OCTAL_DIGIT? OCTAL_DIGIT? - ; - -fragment HEXADECIMAL_ESCAPE_SEQUENCE - : '\\x' HEXADECIMAL_DIGIT HEXADECIMAL_DIGIT? - ; - -fragment NON_ESCAPED_SEQUENCE - : '\\' NON_ESCAPED_DOUBLE_QUOTED_STRING_CHARACTER - ; - -fragment NON_ESCAPED_DOUBLE_QUOTED_STRING_CHARACTER - : ~[\r\nA-Za-z0-9] - ; - -fragment SIMPLE_ESCAPE_SEQUENCE - : '\\' DOUBLE_ESCAPED_CHARACTER - ; - -fragment DOUBLE_ESCAPED_CHARACTER - : [ntrfvaebsu] - ; - -// -------------------------------------------------------- -// Expanded delimited string mode -// -------------------------------------------------------- - -mode EXPANDED_DELIMITED_STRING_MODE; - -DELIMITED_STRING_INTERPOLATION_BEGIN - : '#{' - { - pushInterpolationEndTokenType(DELIMITED_STRING_INTERPOLATION_END); - pushMode(DEFAULT_MODE); - } - ; - -EXPANDED_VARIABLE_CHARACTER_SEQUENCE - : INTERPOLATED_CHARACTER_SEQUENCE_FRAGMENT - ; - -EXPANDED_LITERAL_CHARACTER - : NON_EXPANDED_LITERAL_ESCAPE_SEQUENCE - | NON_ESCAPED_LITERAL_CHARACTER - { - consumeQuotedCharAndMaybePopMode(_input.LA(-1)); - } - ; - -// -------------------------------------------------------- -// Non-expanded delimited string mode -// -------------------------------------------------------- - -mode NON_EXPANDED_DELIMITED_STRING_MODE; - - -fragment NON_EXPANDED_LITERAL_ESCAPE_SEQUENCE - : '\\' NON_ESCAPED_LITERAL_CHARACTER - ; - -fragment NON_ESCAPED_LITERAL_CHARACTER - : ~[\r\n] - | '\n' {_input.LA(1) != '\r'}? - ; - -NON_EXPANDED_LITERAL_CHARACTER - : NON_EXPANDED_LITERAL_ESCAPE_SEQUENCE - | NON_ESCAPED_LITERAL_CHARACTER - { - consumeQuotedCharAndMaybePopMode(_input.LA(-1)); - } - ; - -// -------------------------------------------------------- -// Expanded delimited array mode -// -------------------------------------------------------- - -mode EXPANDED_DELIMITED_ARRAY_MODE; - -DELIMITED_ARRAY_ITEM_INTERPOLATION_BEGIN - : '#{' - { - pushInterpolationEndTokenType(DELIMITED_ARRAY_ITEM_INTERPOLATION_END); - pushMode(DEFAULT_MODE); - } - ; - -EXPANDED_ARRAY_ITEM_SEPARATOR - : NON_EXPANDED_ARRAY_ITEM_DELIMITER - ; - -EXPANDED_ARRAY_ITEM_CHARACTER - : NON_EXPANDED_LITERAL_ESCAPE_SEQUENCE - | NON_ESCAPED_LITERAL_CHARACTER - { - consumeQuotedCharAndMaybePopMode(_input.LA(-1)); - } - ; - -// -------------------------------------------------------- -// Non-expanded delimited array mode -// -------------------------------------------------------- - -mode NON_EXPANDED_DELIMITED_ARRAY_MODE; - -fragment NON_EXPANDED_ARRAY_ITEM_DELIMITER - : [\u0009] - | [\u000a] - | [\u000b] - | [\u000c] - | [\u000d] - | [\u0020] - | '\\' ('\r'? '\n') - ; - -NON_EXPANDED_ARRAY_ITEM_SEPARATOR - : NON_EXPANDED_ARRAY_ITEM_DELIMITER - ; - -NON_EXPANDED_ARRAY_ITEM_CHARACTER - : NON_EXPANDED_LITERAL_ESCAPE_SEQUENCE - | NON_ESCAPED_LITERAL_CHARACTER - { - consumeQuotedCharAndMaybePopMode(_input.LA(-1)); - } - ; - -// -------------------------------------------------------- -// Regex literal mode -// -------------------------------------------------------- - -mode REGULAR_EXPRESSION_MODE; - -REGULAR_EXPRESSION_END - : '/' REGULAR_EXPRESSION_OPTION* - -> popMode - ; - -REGULAR_EXPRESSION_BODY - : REGULAR_EXPRESSION_CHARACTER+ - ; - -REGULAR_EXPRESSION_INTERPOLATION_BEGIN - : '#{' - { - pushInterpolationEndTokenType(REGULAR_EXPRESSION_INTERPOLATION_END); - pushMode(DEFAULT_MODE); - } - ; - -fragment REGULAR_EXPRESSION_OPTION - : [imxo] - ; - -fragment REGULAR_EXPRESSION_CHARACTER - : ~[/#\\] - | '#' {_input.LA(1) != '$' && _input.LA(1) != '@' && _input.LA(1) != '{'}? - | REGULAR_EXPRESSION_NON_ESCAPED_SEQUENCE - | REGULAR_EXPRESSION_ESCAPE_SEQUENCE - | LINE_TERMINATOR_ESCAPE_SEQUENCE - | INTERPOLATED_CHARACTER_SEQUENCE - ; - -fragment REGULAR_EXPRESSION_NON_ESCAPED_SEQUENCE - : '\\' REGULAR_EXPRESSION_NON_ESCAPED_CHARACTER - ; - -fragment REGULAR_EXPRESSION_NON_ESCAPED_CHARACTER - : ~[\r\n] - | '\n' {_input.LA(1) != '\r'}? - ; - -fragment REGULAR_EXPRESSION_ESCAPE_SEQUENCE - : '\\' '/' - ; - -// -------------------------------------------------------- -// Data section mode -// -------------------------------------------------------- - -mode DATA_SECTION_MODE; - -DATA_SECTION_CONTENT - : .*? EOF - -> popMode, skip - ; diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/parser/RubyParser.g4 b/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/parser/RubyParser.g4 deleted file mode 100644 index 986106504530..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/antlr4/io/joern/rubysrc2cpg/parser/RubyParser.g4 +++ /dev/null @@ -1,832 +0,0 @@ -parser grammar RubyParser; - -@header { - package io.joern.rubysrc2cpg.parser; -} - -options { - tokenVocab = RubyLexer; -} - -// -------------------------------------------------------- -// Program -// -------------------------------------------------------- - -program - : compoundStatement EOF - ; - -compoundStatement - : statements? (SEMI | NL)* - ; - -statements - : (SEMI | NL)* statement ((SEMI | NL)+ statement)* - ; - -statement - : expressionOrCommand - # expressionOrCommandStatement - | ALIAS NL* oldName=definedMethodNameOrSymbol newName=definedMethodNameOrSymbol - # aliasStatement - | UNDEF NL* definedMethodNameOrSymbol (COMMA NL* definedMethodNameOrSymbol)* - # undefStatement - | statement statementModifier NL* expressionOrCommand - # modifierStatement - | singleAssignmentStatement - # singleAssignmentStatementStatement - | multipleAssignmentStatement - # multipleAssignmentStatementStatement - ; - -definedMethodNameOrSymbol - : definedMethodName - | symbol - ; - -singleAssignmentStatement - : variable assignmentOperator NL* methodInvocationWithoutParentheses - | COLON2 CONSTANT_IDENTIFIER assignmentOperator NL* methodInvocationWithoutParentheses - | primary LBRACK indexingArgumentList? RBRACK assignmentOperator NL* methodInvocationWithoutParentheses - | primary (DOT | COLON2) methodName assignmentOperator NL* methodInvocationWithoutParentheses - ; - -multipleAssignmentStatement - : leftHandSide EQ NL* multipleRightHandSide - | packingLeftHandSide EQ NL* (methodInvocationWithoutParentheses | operatorExpression) - | multipleLeftHandSide EQ NL* multipleRightHandSide - | multipleLeftHandSideExceptPacking EQ NL* (methodInvocationWithoutParentheses | operatorExpression) - ; - -leftHandSide - : variable (EQ primary)? - # variableLeftHandSide - | primary LBRACK indexingArgumentList? RBRACK - # indexingLeftHandSide - | primary (DOT | COLON2) (LOCAL_VARIABLE_IDENTIFIER | CONSTANT_IDENTIFIER) - # memberAccessLeftHandSide - | COLON2 CONSTANT_IDENTIFIER - # qualifiedLeftHandSide - ; - -multipleLeftHandSide - : (multipleLeftHandSideItem COMMA)+ multipleLeftHandSideItem? - | (multipleLeftHandSideItem COMMA)+ packingLeftHandSide? (COMMA? NL* procParameter)? COMMA? - | packingLeftHandSide - | groupedLeftHandSide - ; - -multipleLeftHandSideExceptPacking - : (multipleLeftHandSideItem COMMA)+ multipleLeftHandSideItem? - | (multipleLeftHandSideItem COMMA)+ packingLeftHandSide? - | groupedLeftHandSide - ; - -packingLeftHandSide - : STAR leftHandSide? - | STAR leftHandSide (COMMA multipleLeftHandSideItem)* - ; - -groupedLeftHandSide - : LPAREN multipleLeftHandSide RPAREN - ; - -multipleLeftHandSideItem - : leftHandSide - | groupedLeftHandSide - ; - -multipleRightHandSide - : operatorExpressionList (COMMA splattingRightHandSide)? - | splattingRightHandSide - ; - -splattingRightHandSide - : splattingArgument - ; - -// -------------------------------------------------------- -// Method invocation expressions -// -------------------------------------------------------- - -methodIdentifier - : LOCAL_VARIABLE_IDENTIFIER - | CONSTANT_IDENTIFIER - | methodOnlyIdentifier - ; - -methodName - : methodIdentifier - | keyword - | pseudoVariable - ; - -methodOnlyIdentifier - : (CONSTANT_IDENTIFIER | LOCAL_VARIABLE_IDENTIFIER | pseudoVariable) (EMARK | QMARK | EQ) - ; - -methodInvocationWithoutParentheses - : command - # commandMethodInvocationWithoutParentheses - | chainedCommandWithDoBlock ((DOT | COLON2) methodName commandArgumentList)? - # chainedMethodInvocationWithoutParentheses - | RETURN primaryValueList - # returnMethodInvocationWithoutParentheses - | BREAK primaryValueList - # breakMethodInvocationWithoutParentheses - | NEXT primaryValueList - # nextMethodInvocationWithoutParentheses - | YIELD primaryValueList - # yieldMethodInvocationWithoutParentheses - ; - -command - : primary NL? (AMPDOT | DOT | COLON2) methodName commandArgument - # memberAccessCommand - | methodIdentifier commandArgument - # simpleCommand - ; - -commandArgument - : commandArgumentList - # commandArgumentCommandArgumentList - | command - # commandCommandArgumentList - ; - -chainedCommandWithDoBlock - : commandWithDoBlock chainedMethodInvocation* - ; - -chainedMethodInvocation - : (DOT | COLON2) methodName argumentWithParentheses? - ; - -commandWithDoBlock - : SUPER argumentList doBlock - | methodIdentifier argumentList doBlock - | primary (DOT | COLON2) methodName argumentList doBlock - ; - -indexingArgumentList - : command - # commandIndexingArgumentList - | operatorExpressionList COMMA? - # operatorExpressionListIndexingArgumentList - | operatorExpressionList COMMA splattingArgument - # operatorExpressionListWithSplattingArgumentIndexingArgumentList - | associationList COMMA? - # associationListIndexingArgumentList - | splattingArgument - # splattingArgumentIndexingArgumentList - ; - -splattingArgument - : STAR operatorExpression - | STAR2 operatorExpression - ; - -operatorExpressionList - : operatorExpression (COMMA NL* operatorExpression)* - ; - -operatorExpressionList2 - : operatorExpression (COMMA NL* operatorExpression)+ - ; - -argumentWithParentheses - : LPAREN NL* COMMA? NL* RPAREN - # emptyArgumentWithParentheses - | LPAREN NL* argumentList COMMA? NL* RPAREN - # argumentListArgumentWithParentheses - | LPAREN NL* operatorExpressionList COMMA NL* chainedCommandWithDoBlock COMMA? NL* RPAREN - # operatorExpressionsAndChainedCommandWithBlockArgumentWithParentheses - | LPAREN NL* chainedCommandWithDoBlock COMMA? NL* RPAREN - # chainedCommandWithDoBlockArgumentWithParentheses - ; - -argumentList - : blockArgument - # blockArgumentArgumentList - | splattingArgument (COMMA NL* blockArgument)? - # splattingArgumentArgumentList - | operatorExpressionList (COMMA NL* associationList)? (COMMA NL* splattingArgument)? (COMMA NL* blockArgument)? - # operatorsArgumentList - | associationList (COMMA NL* splattingArgument)? (COMMA NL* blockArgument)? - # associationsArgumentList - | command - # singleCommandArgumentList - ; - -commandArgumentList - : associationList - | primaryValueList (COMMA NL* associationList)? - ; - -primaryValueList - : primaryValue (COMMA NL* primaryValue)* - ; - -blockArgument - : AMP operatorExpression - ; - -// -------------------------------------------------------- -// Expressions -// -------------------------------------------------------- - -expressionOrCommand - : operatorExpression - # operatorExpressionOrCommand - | EMARK? methodInvocationWithoutParentheses - # commandExpressionOrCommand - | NOT NL* expressionOrCommand - # notExpressionOrCommand - | lhs=expressionOrCommand binOp=(AND|OR) NL* rhs=expressionOrCommand - # keywordAndOrExpressionOrCommand - ; - -operatorExpression - : primary - # primaryOperatorExpression - | operatorExpression QMARK NL* operatorExpression NL* COLON NL* operatorExpression - # ternaryOperatorExpression - ; - -primary - : RETURN - # returnWithoutArguments - | BREAK - # breakWithoutArguments - | NEXT - # nextWithoutArguments - | REDO - # redoWithoutArguments - | RETRY - # retryWithoutArguments - | primaryValue - # primaryValuePrimary - ; - -primaryValue - : // Assignment expressions - lhs=variable assignmentOperator NL* rhs=operatorExpression - # localVariableAssignmentExpression - | primaryValue op=(DOT | COLON2) methodName assignmentOperator NL* operatorExpression - # attributeAssignmentExpression - | COLON2 CONSTANT_IDENTIFIER assignmentOperator NL* operatorExpression - # constantAssignmentExpression - | primaryValue LBRACK indexingArgumentList? RBRACK assignmentOperator NL* operatorExpression - # bracketAssignmentExpression - | primaryValue assignmentOperator NL* operatorExpression RESCUE operatorExpression - # assignmentWithRescue - - // Definitions - | CLASS classPath (LT commandOrPrimaryValueClass)? (SEMI | NL) bodyStatement END - # classDefinition - | CLASS LT2 commandOrPrimaryValueClass (SEMI | NL) bodyStatement END - # singletonClassDefinition - | MODULE classPath bodyStatement END - # moduleDefinition - | DEF definedMethodName methodParameterPart bodyStatement END - # methodDefinition - | DEF singletonObject op=(DOT | COLON2) definedMethodName methodParameterPart bodyStatement END - # singletonMethodDefinition - | DEF definedMethodName (LPAREN parameterList? RPAREN)? EQ NL* statement - # endlessMethodDefinition - | MINUSGT (LPAREN parameterList? RPAREN)? block - # lambdaExpression - - // Control structures - | IF NL* expressionOrCommand thenClause elsifClause* elseClause? END - # ifExpression - | UNLESS NL* expressionOrCommand thenClause elseClause? END - # unlessExpression - | UNTIL NL* expressionOrCommand doClause END - # untilExpression - | YIELD argumentWithParentheses? - # yieldExpression - | BEGIN bodyStatement END - # beginEndExpression - | CASE NL* expressionOrCommand (SEMI | NL)* whenClause+ elseClause? END - # caseWithExpression - | CASE (SEMI | NL)* whenClause+ elseClause? END - # caseWithoutExpression - | WHILE NL* expressionOrCommand doClause END - # whileExpression - | FOR NL* forVariable IN NL* commandOrPrimaryValue doClause END - # forExpression - - // Non-nested calls - | SUPER argumentWithParentheses? block? - # superWithParentheses - | SUPER argumentList? block? - # superWithoutParentheses - | isDefinedKeyword LPAREN expressionOrCommand RPAREN - # isDefinedExpression - | isDefinedKeyword primaryValue - # isDefinedCommand - | methodOnlyIdentifier - # methodCallExpression - | methodIdentifier block - # methodCallWithBlockExpression - | methodIdentifier argumentWithParentheses block? - # methodCallWithParenthesesExpression - | variableReference - # methodCallOrVariableReference - - // Literals - | LBRACK NL* indexingArgumentList? NL* RBRACK - # bracketedArrayLiteral - | QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_START quotedNonExpandedArrayElementList? QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_END - # quotedNonExpandedStringArrayLiteral - | QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_START quotedNonExpandedArrayElementList? QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_END - # quotedNonExpandedSymbolArrayLiteral - | QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START quotedExpandedArrayElementList? QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END - # quotedExpandedStringArrayLiteral - | QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_START quotedExpandedArrayElementList? QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_END - # quotedExpandedSymbolArrayLiteral - | LCURLY NL* (associationList COMMA?)? NL* RCURLY - # hashLiteral - | sign=(PLUS | MINUS)? unsignedNumericLiteral - # numericLiteral - | singleQuotedString singleOrDoubleQuotedString* - # singleQuotedStringExpression - | doubleQuotedString singleOrDoubleQuotedString* - # doubleQuotedStringExpression - | QUOTED_NON_EXPANDED_STRING_LITERAL_START NON_EXPANDED_LITERAL_CHARACTER_SEQUENCE? QUOTED_NON_EXPANDED_STRING_LITERAL_END - # quotedNonExpandedStringLiteral - | QUOTED_EXPANDED_STRING_LITERAL_START quotedExpandedLiteralStringContent* QUOTED_EXPANDED_STRING_LITERAL_END - # quotedExpandedStringLiteral - | QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_START quotedExpandedLiteralStringContent* QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END - # quotedExpandedExternalCommandLiteral - | symbol - # symbolExpression - | REGULAR_EXPRESSION_START regexpLiteralContent* REGULAR_EXPRESSION_END - # regularExpressionLiteral - | QUOTED_EXPANDED_REGULAR_EXPRESSION_START quotedExpandedLiteralStringContent* QUOTED_EXPANDED_REGULAR_EXPRESSION_END - # quotedExpandedRegularExpressionLiteral - - | LPAREN compoundStatement RPAREN - # groupingStatement - - // Member accesses - | primaryValue LBRACK indexingArgumentList? RBRACK - # indexingAccessExpression - | primaryValue NL* op=(AMPDOT | DOT | COLON2) NL* methodName argumentWithParentheses? block? - # memberAccessExpression - - // Unary and binary expressions - | unaryOperator primaryValue - # unaryExpression - | primaryValue powerOperator=STAR2 NL* primaryValue - # powerExpression - | MINUS primaryValue - # unaryMinusExpression - | primaryValue multiplicativeOperator NL* primaryValue - # multiplicativeExpression - | primaryValue additiveOperator NL* primaryValue - # additiveExpression - | primaryValue bitwiseShiftOperator NL* primaryValue - # shiftExpression - | primaryValue bitwiseAndOperator=AMP NL* primaryValue - # bitwiseAndExpression - | primaryValue bitwiseOrOperator NL* primaryValue - # bitwiseOrExpression - | primaryValue relationalOperator NL* primaryValue - # relationalExpression - | primaryValue equalityOperator NL* primaryValue - # equalityExpression - | primaryValue andOperator=AMP2 NL* primaryValue - # logicalAndExpression - | primaryValue orOperator=BAR2 NL* primaryValue - # logicalOrExpression - | primaryValue rangeOperator NL* primaryValue - # rangeExpression - | hereDoc - # hereDocs - ; - -// This is required to make chained calls work. For classes, we cannot move up the `primaryValue` due to the possible -// presence of AMPDOT when inheriting (class Foo < Bar::Baz), but the command rule doesn't allow chained calls -// in if statements to be created properly, and ends throwing away everything after the first call. Splitting these -// allows us to have a rule for the class that parses properly, and a rule for everything else that allows us to move -// up the `primaryValue` rule to the top. -commandOrPrimaryValueClass - : command - # commandCommandOrPrimaryValueClass - | primaryValue - # primaryValueCommandOrPrimaryValueClass - ; - -commandOrPrimaryValue - : primaryValue - # primaryValueCommandOrPrimaryValue - | command - # commandCommandOrPrimaryValue - | NOT commandOrPrimaryValue - # notCommandOrPrimaryValue - | commandOrPrimaryValue (AND|OR) NL* commandOrPrimaryValue - # keywordAndOrCommandOrPrimaryValue - ; - -block - : LCURLY NL* blockParameter? compoundStatement RCURLY - # curlyBracesBlock - | doBlock - # doBlockBlock - ; - -doBlock - : DO NL* blockParameter? bodyStatement END - ; - -blockParameter - : BAR NL* BAR - | BAR NL* parameterList NL* BAR - ; - -thenClause - : (SEMI | NL)+ compoundStatement - | (SEMI | NL)? THEN compoundStatement - ; - -elseClause - : ELSE compoundStatement - ; - -elsifClause - : ELSIF NL* expressionOrCommand thenClause - ; - -whenClause - : WHEN NL* whenArgument thenClause - ; - -whenArgument - : operatorExpressionList (COMMA splattingArgument)? - | splattingArgument - ; - -doClause - : (SEMI | NL)+ compoundStatement - | DO compoundStatement - ; - -forVariable - : leftHandSide - | multipleLeftHandSide - ; - -bodyStatement - : compoundStatement rescueClause* elseClause? ensureClause? - ; - -rescueClause - : RESCUE exceptionClassList? exceptionVariableAssignment? thenClause - ; - -exceptionClassList - : operatorExpression - | multipleRightHandSide - ; - -exceptionVariableAssignment - : EQGT leftHandSide - ; - -ensureClause - : ENSURE compoundStatement - ; - -definedMethodName - : methodName - | ASSIGNMENT_LIKE_METHOD_IDENTIFIER - | LBRACK RBRACK EQ? - | EQ2 - | EQ3 - | LTEQGT - | LT2 - ; - -methodParameterPart - : LPAREN NL* parameterList? NL* RPAREN - | parameterList? (SEMI | NL) - ; - -parameterList - : mandatoryOrOptionalParameterList (COMMA NL* arrayParameter)? (COMMA NL* hashParameter)? (COMMA NL* procParameter)? - | arrayParameter (COMMA NL* hashParameter)? (COMMA NL* procParameter)? - | hashParameter (COMMA NL* procParameter)? - | procParameter - ; - -mandatoryOrOptionalParameterList - : mandatoryOrOptionalParameter (COMMA NL* mandatoryOrOptionalParameter)* - ; - -mandatoryOrOptionalParameter - : mandatoryParameter - # mandatoryMandatoryOrOptionalParameter - | optionalParameter - # optionalMandatoryOrOptionalParameter - ; - -mandatoryParameter - : LOCAL_VARIABLE_IDENTIFIER COLON? - ; - -optionalParameter - : optionalParameterName (EQ|COLON) NL* operatorExpression - ; - -optionalParameterName - : LOCAL_VARIABLE_IDENTIFIER - ; - -arrayParameter - : STAR LOCAL_VARIABLE_IDENTIFIER? - ; - -hashParameter - : STAR2 LOCAL_VARIABLE_IDENTIFIER? - ; - -procParameter - : AMP procParameterName - ; - -procParameterName - : LOCAL_VARIABLE_IDENTIFIER - ; - -classPath - : COLON2 CONSTANT_IDENTIFIER - # topClassPath - | CONSTANT_IDENTIFIER - # className - | classPath COLON2 CONSTANT_IDENTIFIER - # nestedClassPath - ; - -singletonObject - : variableReference - #variableReferenceSingletonObject - | LPAREN expressionOrCommand RPAREN - #expressionSingletonObject - ; - -variableReference - : variable - # variableVariableReference - | pseudoVariable - # pseudoVariableVariableReference - | COLON2 CONSTANT_IDENTIFIER - # constantVariableReference - ; - -associationList - : association (COMMA NL* association)* - ; - -association - : associationKey (EQGT | COLON) NL* operatorExpression - ; - -associationKey - : operatorExpression - | keyword - ; - -regexpLiteralContent - : REGULAR_EXPRESSION_BODY - | REGULAR_EXPRESSION_INTERPOLATION_BEGIN compoundStatement REGULAR_EXPRESSION_INTERPOLATION_END - ; - -singleQuotedString - : SINGLE_QUOTED_STRING_LITERAL - ; - -singleOrDoubleQuotedString - : singleQuotedString - | doubleQuotedString - ; - -doubleQuotedString - : DOUBLE_QUOTED_STRING_START doubleQuotedStringContent* DOUBLE_QUOTED_STRING_END - ; - -quotedExpandedExternalCommandString - : QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_START - quotedExpandedLiteralStringContent* - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END - ; - -doubleQuotedStringContent - : DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE - | STRING_INTERPOLATION_BEGIN compoundStatement STRING_INTERPOLATION_END - ; - -quotedNonExpandedLiteralString - : QUOTED_NON_EXPANDED_STRING_LITERAL_START NON_EXPANDED_LITERAL_CHARACTER_SEQUENCE? QUOTED_NON_EXPANDED_STRING_LITERAL_END - ; - -quotedExpandedLiteralString - : QUOTED_EXPANDED_STRING_LITERAL_START quotedExpandedLiteralStringContent* QUOTED_EXPANDED_STRING_LITERAL_END - ; - -quotedExpandedLiteralStringContent - : EXPANDED_LITERAL_CHARACTER_SEQUENCE - | DELIMITED_STRING_INTERPOLATION_BEGIN compoundStatement DELIMITED_STRING_INTERPOLATION_END - ; - -quotedNonExpandedArrayElementContent - : NON_EXPANDED_ARRAY_ITEM_CHARACTER+ - ; - -quotedExpandedArrayElementContent - : EXPANDED_ARRAY_ITEM_CHARACTER - | DELIMITED_ARRAY_ITEM_INTERPOLATION_BEGIN compoundStatement DELIMITED_ARRAY_ITEM_INTERPOLATION_END - ; - -quotedExpandedArrayElement - : quotedExpandedArrayElementContent+ - ; - -quotedNonExpandedArrayElementList - : NON_EXPANDED_ARRAY_ITEM_SEPARATOR* - quotedNonExpandedArrayElementContent - (NON_EXPANDED_ARRAY_ITEM_SEPARATOR+ quotedNonExpandedArrayElementContent)* - NON_EXPANDED_ARRAY_ITEM_SEPARATOR* - ; - -quotedExpandedArrayElementList - : EXPANDED_ARRAY_ITEM_SEPARATOR* - quotedExpandedArrayElement - (EXPANDED_ARRAY_ITEM_SEPARATOR+ quotedExpandedArrayElement)* - EXPANDED_ARRAY_ITEM_SEPARATOR* - ; - -symbol - : SYMBOL_LITERAL - # pureSymbolLiteral - | COLON singleQuotedString - # singleQuotedSymbolLiteral - | COLON doubleQuotedString - # doubleQuotedSymbolLiteral - ; - -hereDoc - : HERE_DOC - ; - -// -------------------------------------------------------- -// Commons -// -------------------------------------------------------- - -isDefinedKeyword - : IS_DEFINED - ; - -assignmentOperator - : EQ - | ASSIGNMENT_OPERATOR - ; - -statementModifier - : IF - | UNLESS - | WHILE - | UNTIL - | RESCUE - ; - -variable - : CONSTANT_IDENTIFIER - # constantIdentifierVariable - | GLOBAL_VARIABLE_IDENTIFIER - # globalIdentifierVariable - | CLASS_VARIABLE_IDENTIFIER - # classIdentifierVariable - | INSTANCE_VARIABLE_IDENTIFIER - # instanceIdentifierVariable - | LOCAL_VARIABLE_IDENTIFIER - # localIdentifierVariable - ; - -pseudoVariable - : NIL - # nilPseudoVariable - | TRUE - # truePseudoVariable - | FALSE - # falsePseudoVariable - | SELF - # selfPseudoVariable - | LINE__ - # linePseudoVariable - | FILE__ - # filePseudoVariable - | ENCODING__ - # encodingPseudoVariable - ; - -unsignedNumericLiteral - : DECIMAL_INTEGER_LITERAL - # decimalUnsignedLiteral - | BINARY_INTEGER_LITERAL - # binaryUnsignedLiteral - | OCTAL_INTEGER_LITERAL - # octalUnsignedLiteral - | HEXADECIMAL_INTEGER_LITERAL - # hexadecimalUnsignedLiteral - | FLOAT_LITERAL_WITHOUT_EXPONENT - # floatWithoutExponentUnsignedLiteral - | FLOAT_LITERAL_WITH_EXPONENT - # floatWithExponentUnsignedLiteral - ; - -unaryOperator - : TILDE - | PLUS - | EMARK - ; - -multiplicativeOperator - : STAR - | SLASH - | PERCENT - ; - -additiveOperator - : PLUS - | MINUS - ; - -bitwiseShiftOperator - : LT2 - | GT2 - ; - -bitwiseOrOperator - : BAR - | CARET - ; - -relationalOperator - : GT - | GTEQ - | LT - | LTEQ - ; - -equalityOperator - : LTEQGT - | EQ2 - | EQ3 - | EMARKEQ - | EQTILDE - | EMARKTILDE - ; - -rangeOperator - : DOT2 - | DOT3 - ; - -keyword - : BEGIN_ - | END_ - | ALIAS - | AND - | BEGIN - | BREAK - | CASE - | CLASS - | DEF - | IS_DEFINED - | DO - | ELSE - | ELSIF - | END - | ENSURE - | FOR - | IF - | IN - | MODULE - | NEXT - | NOT - | OR - | REDO - | RESCUE - | RETRY - | RETURN - | SUPER - | THEN - | UNDEF - | UNLESS - | UNTIL - | WHEN - | WHILE - | YIELD - ; \ No newline at end of file diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/resources/application.conf b/joern-cli/frontends/rubysrc2cpg/src/main/resources/application.conf old mode 100644 new mode 100755 index 08cc7c1ee753..b9b734d9b89c --- a/joern-cli/frontends/rubysrc2cpg/src/main/resources/application.conf +++ b/joern-cli/frontends/rubysrc2cpg/src/main/resources/application.conf @@ -1,3 +1,4 @@ rubysrc2cpg { + ruby_ast_gen_version: "0.33.0" joern_type_stubs_version: "0.6.0" } \ No newline at end of file diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/resources/log4j2.xml b/joern-cli/frontends/rubysrc2cpg/src/main/resources/log4j2.xml old mode 100644 new mode 100755 diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/Main.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/Main.scala index f1e97e8e61aa..eec584366e38 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/Main.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/Main.scala @@ -1,31 +1,29 @@ package io.joern.rubysrc2cpg import io.joern.rubysrc2cpg.Frontend.* +import io.joern.x2cpg.astgen.AstGenConfig import io.joern.x2cpg.passes.frontend.{TypeRecoveryParserConfig, XTypeRecovery, XTypeRecoveryConfig} import io.joern.x2cpg.typestub.TypeStubConfig +import io.joern.x2cpg.utils.server.FrontendHTTPServer import io.joern.x2cpg.{DependencyDownloadConfig, X2CpgConfig, X2CpgMain} import scopt.OParser -final case class Config( - antlrCacheMemLimit: Double = 0.6d, - useDeprecatedFrontend: Boolean = false, - downloadDependencies: Boolean = false, - useTypeStubs: Boolean = true -) extends X2CpgConfig[Config] +import java.nio.file.Paths + +final case class Config(downloadDependencies: Boolean = false, useTypeStubs: Boolean = true) + extends X2CpgConfig[Config] with DependencyDownloadConfig[Config] with TypeRecoveryParserConfig[Config] - with TypeStubConfig[Config] { - - this.defaultIgnoredFilesRegex = List("spec", "test", "tests").flatMap { directory => - List(s"(^|\\\\)$directory($$|\\\\)".r.unanchored, s"(^|/)$directory($$|/)".r.unanchored) - } + with TypeStubConfig[Config] + with AstGenConfig[Config] { - def withAntlrCacheMemoryLimit(value: Double): Config = { - copy(antlrCacheMemLimit = value).withInheritedFields(this) - } + override val astGenProgramName: String = "ruby_ast_gen" + override val astGenConfigPrefix: String = "rubysrc2cpg" + override val multiArchitectureBuilds: Boolean = true - def withUseDeprecatedFrontend(value: Boolean): Config = { - copy(useDeprecatedFrontend = value).withInheritedFields(this) + this.defaultIgnoredFilesRegex = List("spec", "tests?", "vendor", "db(\\\\|/)([\\w_]*)migrate([_\\w]*)").flatMap { + directory => + List(s"(^|\\\\)$directory($$|\\\\)".r.unanchored, s"(^|/)$directory($$|/)".r.unanchored) } override def withDownloadDependencies(value: Boolean): Config = { @@ -46,21 +44,9 @@ private object Frontend { import builder.* OParser.sequence( programName("rubysrc2cpg"), - opt[Double]("antlrCacheMemLimit") - .hidden() - .action((x, c) => c.withAntlrCacheMemoryLimit(x)) - .validate { - case x if x < 0.3 => - failure(s"$x may result in too many evictions and reduce performance, try a value between 0.3 - 0.8.") - case x if x > 0.8 => - failure(s"$x may result in too much memory usage and thrashing, try a value between 0.3 - 0.8.") - case x => - success - } - .text("sets the heap usage threshold at which the ANTLR DFA cache is cleared during parsing (default 0.6)"), - opt[Unit]("useDeprecatedFrontend") - .action((_, c) => c.withUseDeprecatedFrontend(true)) - .text("uses the original (but deprecated) Ruby frontend (default false)"), + opt[Unit]("enable-file-content") + .action((_, c) => c.withDisableFileContent(false)) + .text("Enable file content"), DependencyDownloadConfig.parserOptions, XTypeRecoveryConfig.parserOptionsForParserConfig, TypeStubConfig.parserOptions @@ -68,8 +54,12 @@ private object Frontend { } } -object Main extends X2CpgMain(cmdLineParser, new RubySrc2Cpg()) { +object Main extends X2CpgMain(cmdLineParser, new RubySrc2Cpg()) with FrontendHTTPServer[Config, RubySrc2Cpg] { + + override protected def newDefaultConfig(): Config = Config() + def run(config: Config, rubySrc2Cpg: RubySrc2Cpg): Unit = { - rubySrc2Cpg.run(config) + if (config.serverMode) { startup() } + else { rubySrc2Cpg.run(config) } } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/RubySrc2Cpg.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/RubySrc2Cpg.scala index 00b61157a7db..e84b119c3e8b 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/RubySrc2Cpg.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/RubySrc2Cpg.scala @@ -2,32 +2,28 @@ package io.joern.rubysrc2cpg import better.files.File import io.joern.rubysrc2cpg.astcreation.AstCreator +import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.StatementList import io.joern.rubysrc2cpg.datastructures.RubyProgramSummary -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.* -import io.joern.rubysrc2cpg.parser.RubyParser +import io.joern.rubysrc2cpg.parser.* import io.joern.rubysrc2cpg.passes.{ AstCreationPass, ConfigFileCreationPass, DependencyPass, - DependencySummarySolverPass, - ImplicitRequirePass, - ImportsPass, - RubyImportResolverPass, - RubyTypeHintCallLinker + DependencySummarySolverPass } import io.joern.rubysrc2cpg.utils.DependencyDownloader import io.joern.x2cpg.X2Cpg.withNewEmptyCpg +import io.joern.x2cpg.frontendspecific.rubysrc2cpg.* import io.joern.x2cpg.passes.base.AstLinkerPass import io.joern.x2cpg.passes.callgraph.NaiveCallLinker import io.joern.x2cpg.passes.frontend.{MetaDataPass, TypeNodePass, XTypeRecoveryConfig} import io.joern.x2cpg.utils.{ConcurrentTaskUtil, ExternalCommand} import io.joern.x2cpg.{SourceFiles, X2CpgFrontend} -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.Languages +import io.shiftleft.codepropertygraph.generated.{Cpg, Languages} import io.shiftleft.passes.CpgPassBase import io.shiftleft.semanticcpg.language.* import org.slf4j.LoggerFactory +import upickle.default.* import java.nio.file.{Files, Paths} import scala.util.matching.Regex @@ -42,23 +38,27 @@ class RubySrc2Cpg extends X2CpgFrontend[Config] { new MetaDataPass(cpg, Languages.RUBYSRC, config.inputPath).createAndApply() new ConfigFileCreationPass(cpg).createAndApply() new DependencyPass(cpg).createAndApply() - if (config.useDeprecatedFrontend) { - deprecatedCreateCpgAction(cpg, config) - } else { - newCreateCpgAction(cpg, config) - } + createCpgAction(cpg, config) } } - private def newCreateCpgAction(cpg: Cpg, config: Config): Unit = { - Using.resource(new parser.ResourceManagedParser(config.antlrCacheMemLimit)) { parser => + private def createCpgAction(cpg: Cpg, config: Config): Unit = { + File.usingTemporaryDirectory("rubysrc2cpgOut") { tmpDir => + val astGenResult = RubyAstGenRunner(config).execute(tmpDir) + val astCreators = ConcurrentTaskUtil - .runUsingThreadPool(RubySrc2Cpg.generateParserTasks(parser, config, cpg.metaData.root.headOption)) + .runUsingThreadPool( + RubySrc2Cpg.processAstGenRunnerResults(astGenResult.parsedFiles, config, cpg.metaData.root.headOption) + ) .flatMap { - case Failure(exception) => logger.warn(s"Could not parse file, skipping - ", exception); None + case Failure(exception) => logger.warn(s"Unable to parse Ruby file, skipping -", exception); None case Success(astCreator) => Option(astCreator) } - // Pre-parse the AST creators for high level structures + .filter(x => { + if x.fileContent.isBlank then logger.info(s"File content empty, skipping - ${x.fileName}") + !x.fileContent.isBlank + }) + val internalProgramSummary = ConcurrentTaskUtil .runUsingThreadPool(astCreators.map(x => () => x.summarize()).iterator) .flatMap { @@ -76,8 +76,6 @@ class RubySrc2Cpg extends X2CpgFrontend[Config] { val programSummary = internalProgramSummary ++= dependencySummary AstCreationPass(cpg, astCreators.map(_.withSummary(programSummary))).createAndApply() - if (cpg.dependency.name.contains("zeitwerk")) ImplicitRequirePass(cpg, programSummary).createAndApply() - ImportsPass(cpg).createAndApply() if config.downloadDependencies then { DependencySummarySolverPass(cpg, dependencySummary).createAndApply() } @@ -85,115 +83,37 @@ class RubySrc2Cpg extends X2CpgFrontend[Config] { } } - private def deprecatedCreateCpgAction(cpg: Cpg, config: Config): Unit = try { - Using.resource(new deprecated.astcreation.ResourceManagedParser(config.antlrCacheMemLimit)) { parser => - if (config.downloadDependencies && !scala.util.Properties.isWin) { - val tempDir = File.newTemporaryDirectory() - try { - downloadDependency(config.inputPath, tempDir.toString()) - new deprecated.passes.AstPackagePass( - cpg, - tempDir.toString(), - parser, - RubySrc2Cpg.packageTableInfo, - config.inputPath - )(config.schemaValidation).createAndApply() - } finally { - tempDir.delete() - } - } - val parsedFiles = { - val tasks = SourceFiles - .determine( - config.inputPath, - RubySrc2Cpg.RubySourceFileExtensions, - ignoredFilesRegex = Option(config.ignoredFilesRegex), - ignoredFilesPath = Option(config.ignoredFiles) - ) - .map(x => - () => - parser.parse(x) match - case Failure(exception) => - logger.warn(s"Could not parse file: $x, skipping", exception); throw exception - case Success(ast) => x -> ast - ) - .iterator - ConcurrentTaskUtil.runUsingThreadPool(tasks).flatMap(_.toOption) - } - - new io.joern.rubysrc2cpg.deprecated.ParseInternalStructures(parsedFiles, cpg.metaData.root.headOption) - .populatePackageTable() - val astCreationPass = - new deprecated.passes.AstCreationPass(cpg, parsedFiles, RubySrc2Cpg.packageTableInfo, config) - astCreationPass.createAndApply() - } - } finally { - RubySrc2Cpg.packageTableInfo.clear() - } - - private def downloadDependency(inputPath: String, tempPath: String): Unit = { - if (Files.isRegularFile(Paths.get(s"${inputPath}${java.io.File.separator}Gemfile"))) { - ExternalCommand.run(s"bundle config set --local path ${tempPath}", inputPath) match { - case Success(configOutput) => - logger.info(s"Gem config successfully done: $configOutput") - case Failure(exception) => - logger.error(s"Error while configuring Gem Path: ${exception.getMessage}") - } - val command = s"bundle install" - ExternalCommand.run(command, inputPath) match { - case Success(bundleOutput) => - logger.info(s"Dependency installed successfully: $bundleOutput") - case Failure(exception) => - logger.error(s"Error while downloading dependency: ${exception.getMessage}") - } - } - } } object RubySrc2Cpg { - // TODO: Global mutable state is bad and should be avoided in the next iteration of the Ruby frontend - val packageTableInfo = new deprecated.utils.PackageTable() - private val RubySourceFileExtensions: Set[String] = Set(".rb") - def postProcessingPasses(cpg: Cpg, config: Config): List[CpgPassBase] = { - if (config.useDeprecatedFrontend) { - List(new deprecated.passes.RubyImportResolverPass(cpg, packageTableInfo)) - ++ new deprecated.passes.RubyTypeRecoveryPassGenerator(cpg).generate() ++ List( - new deprecated.passes.RubyTypeHintCallLinker(cpg), - new NaiveCallLinker(cpg), - - // Some of passes above create new methods, so, we - // need to run the ASTLinkerPass one more time - new AstLinkerPass(cpg) - ) - } else { - List(new RubyImportResolverPass(cpg)) ++ - new passes.RubyTypeRecoveryPassGenerator(cpg, config = XTypeRecoveryConfig(iterations = 4)) - .generate() ++ List(new RubyTypeHintCallLinker(cpg), new NaiveCallLinker(cpg), new AstLinkerPass(cpg)) - } + val implicitRequirePass = if (cpg.dependency.name.contains("zeitwerk")) ImplicitRequirePass(cpg) :: Nil else Nil + implicitRequirePass ++ List(ImportsPass(cpg), RubyImportResolverPass(cpg)) ++ + new RubyTypeRecoveryPassGenerator(cpg, config = XTypeRecoveryConfig(iterations = 4)) + .generate() ++ List(new RubyTypeHintCallLinker(cpg), new NaiveCallLinker(cpg), new AstLinkerPass(cpg)) } - def generateParserTasks( - resourceManagedParser: parser.ResourceManagedParser, + /** Parses the generated AST Gen files in parallel and produces AstCreators from each. + */ + def processAstGenRunnerResults( + astFiles: List[String], config: Config, projectRoot: Option[String] ): Iterator[() => AstCreator] = { - SourceFiles - .determine( - config.inputPath, - RubySourceFileExtensions, - ignoredDefaultRegex = Option(config.defaultIgnoredFilesRegex), - ignoredFilesRegex = Option(config.ignoredFilesRegex), - ignoredFilesPath = Option(config.ignoredFiles) - ) - .map { fileName => () => - resourceManagedParser.parse(File(config.inputPath), fileName) match { - case Failure(exception) => throw exception - case Success(ctx) => new AstCreator(fileName, ctx, projectRoot)(config.schemaValidation) - } - } - .iterator + astFiles.map { fileName => () => + val parserResult = RubyJsonParser.readFile(Paths.get(fileName)) + val rubyProgram = new RubyJsonToNodeCreator().visitProgram(parserResult.json) + val sourceFileName = parserResult.fullPath + val fileContent = File(sourceFileName).contentAsString + new AstCreator( + sourceFileName, + projectRoot, + enableFileContents = !config.disableFileContent, + fileContent = fileContent, + rootNode = rubyProgram + )(config.schemaValidation) + }.iterator } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala index ea422185d4e1..98f3f3b653e7 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala @@ -1,34 +1,38 @@ package io.joern.rubysrc2cpg.astcreation import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.* -import io.joern.rubysrc2cpg.datastructures.{BlockScope, NamespaceScope, RubyProgramSummary, RubyScope, RubyStubbedType} -import io.joern.rubysrc2cpg.parser.{RubyNodeCreator, RubyParser} +import io.joern.rubysrc2cpg.datastructures.{BlockScope, NamespaceScope, RubyProgramSummary, RubyScope} import io.joern.rubysrc2cpg.passes.Defines -import io.joern.x2cpg.utils.NodeBuilders.{newBindingNode, newModifierNode} +import io.joern.rubysrc2cpg.utils.FreshNameGenerator +import io.joern.x2cpg.utils.NodeBuilders.{newModifierNode, newThisParameterNode} import io.joern.x2cpg.{Ast, AstCreatorBase, AstNodeBuilder, ValidationMode} -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, ModifierTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{DiffGraphBuilder, ModifierTypes} import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import org.slf4j.{Logger, LoggerFactory} -import overflowdb.BatchedUpdate -import overflowdb.BatchedUpdate.DiffGraphBuilder import java.util.regex.Matcher class AstCreator( val fileName: String, - protected val programCtx: RubyParser.ProgramContext, protected val projectRoot: Option[String] = None, - protected val programSummary: RubyProgramSummary = RubyProgramSummary() + protected val programSummary: RubyProgramSummary = RubyProgramSummary(), + val enableFileContents: Boolean = false, + val fileContent: String = "", + val rootNode: StatementList )(implicit withSchemaValidation: ValidationMode) extends AstCreatorBase(fileName) with AstCreatorHelper with AstForStatementsCreator with AstForExpressionsCreator + with AstForControlStructuresCreator with AstForFunctionsCreator with AstForTypesCreator with AstSummaryVisitor - with AstNodeBuilder[RubyNode, AstCreator] { + with AstNodeBuilder[RubyExpression, AstCreator] { + + val tmpGen: FreshNameGenerator[String] = FreshNameGenerator(i => s"") + val procParamGen: FreshNameGenerator[Left[String, Nothing]] = FreshNameGenerator(i => Left(s"")) /* Used to track variable names and their LOCAL nodes. */ @@ -36,24 +40,27 @@ class AstCreator( protected val logger: Logger = LoggerFactory.getLogger(getClass) + protected var fileNode: Option[NewFile] = None + protected var parseLevel: AstParseLevel = AstParseLevel.FULL_AST + override protected def offset(node: RubyExpression): Option[(Int, Int)] = node.offset + protected val relativeFileName: String = projectRoot .map(fileName.stripPrefix) .map(_.stripPrefix(java.io.File.separator)) .getOrElse(fileName) - private def internalLineAndColNum: Option[Integer] = Option(1) + private def internalLineAndColNum: Option[Int] = Option(1) /** The relative file name, in a unix path delimited format. */ private def relativeUnixStyleFileName = relativeFileName.replaceAll(Matcher.quoteReplacement(java.io.File.separator), "/") - override def createAst(): BatchedUpdate.DiffGraphBuilder = { - val rootNode = new RubyNodeCreator().visit(programCtx).asInstanceOf[StatementList] - val ast = astForRubyFile(rootNode) + override def createAst(): DiffGraphBuilder = { + val ast = astForRubyFile(rootNode) Ast.storeInDiffGraph(ast, diffGraph) diffGraph } @@ -63,8 +70,11 @@ class AstCreator( * allowing for a straightforward representation of out-of-method statements. */ protected def astForRubyFile(rootStatements: StatementList): Ast = { - val fileNode = NewFile().name(relativeFileName) - val fullName = s"$relativeUnixStyleFileName:${NamespaceTraversal.globalNamespaceName}" + fileNode = + if enableFileContents then Option(NewFile().name(relativeFileName).content(fileContent)) + else Option(NewFile().name(relativeFileName)) + val fullName = s"$relativeUnixStyleFileName:${NamespaceTraversal.globalNamespaceName}".stripPrefix("/") + val namespaceBlock = NewNamespaceBlock() .filename(relativeFileName) .name(NamespaceTraversal.globalNamespaceName) @@ -74,13 +84,15 @@ class AstCreator( val rubyFakeMethodAst = astInFakeMethod(rootStatements) scope.popScope() - Ast(fileNode).withChild(Ast(namespaceBlock).withChild(rubyFakeMethodAst)) + Ast(fileNode.get).withChild(Ast(namespaceBlock).withChild(rubyFakeMethodAst)) } private def astInFakeMethod(rootNode: StatementList): Ast = { - val name = Defines.Program - val fullName = computeMethodFullName(name) - val code = rootNode.text + val name = Defines.Main + // From the
method onwards, we do not embed the namespace name in the full names + val fullName = + s"${scope.surroundingScopeFullName.head.stripSuffix(NamespaceTraversal.globalNamespaceName)}$name" + val code = rootNode.text val methodNode_ = methodNode( node = rootNode, name = name, @@ -89,6 +101,15 @@ class AstCreator( signature = None, fileName = relativeFileName ) + val thisParameterNode = newThisParameterNode( + name = Defines.Self, + code = Defines.Self, + typeFullName = Defines.Any, + line = methodNode_.lineNumber, + column = methodNode_.columnNumber + ) + val thisParameterAst = Ast(thisParameterNode) + scope.addToScope(Defines.Self, thisParameterNode) val methodReturn = methodReturnNode(rootNode, Defines.Any) scope.newProgramScope @@ -102,7 +123,7 @@ class AstCreator( scope.popScope() methodAst( methodNode_, - Seq.empty, + thisParameterAst :: Nil, bodyAst, methodReturn, newModifierNode(ModifierTypes.MODULE) :: newModifierNode(ModifierTypes.VIRTUAL) :: Nil diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala index d519c4e1deb8..f507cbb2b1ef 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala @@ -1,11 +1,23 @@ package io.joern.rubysrc2cpg.astcreation import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{ ClassFieldIdentifier, + ControlFlowStatement, DummyNode, + ElseClause, + IfExpression, + IndexAccess, InstanceFieldIdentifier, MemberAccess, + MemberCall, + RubyExpression, RubyFieldIdentifier, - RubyNode + SelfIdentifier, + SimpleIdentifier, + SingleAssignment, + StatementList, + StaticLiteral, + TextSpan, + UnaryExpression } import io.joern.rubysrc2cpg.datastructures.{BlockScope, FieldDecl} import io.joern.rubysrc2cpg.passes.Defines @@ -15,24 +27,55 @@ import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, Operators} +import scala.collection.mutable + trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - protected def computeClassFullName(name: String): String = s"${scope.surroundingScopeFullName.head}.$name" - protected def computeMethodFullName(name: String): String = s"${scope.surroundingScopeFullName.head}:$name" + private val usedFullNames = mutable.Set.empty[String] - override def column(node: RubyNode): Option[Int] = node.column - override def columnEnd(node: RubyNode): Option[Int] = node.columnEnd - override def line(node: RubyNode): Option[Int] = node.line - override def lineEnd(node: RubyNode): Option[Int] = node.lineEnd - override def code(node: RubyNode): String = shortenCode(node.text) + /** Ensures a unique full name is assigned based on the current scope. + * @param name + * the name of the entity. + * @param counter + * an optional counter, used to create unique instances in the case of redefinitions. + * @param useSurroundingTypeFullName + * flag for whether the fullName is for accessor-like method lowering + * @return + * a unique full name. + */ + protected def computeFullName( + name: String, + counter: Option[Int] = None, + useSurroundingTypeFullName: Boolean = false + ): String = { + val surroundingName = + if useSurroundingTypeFullName then scope.surroundingTypeFullName.head else scope.surroundingScopeFullName.head + val candidate = counter match { + case Some(cnt) => s"$surroundingName.$name$cnt" + case None => s"$surroundingName.$name" + } + if (usedFullNames.contains(candidate)) { + computeFullName(name, counter.map(_ + 1).orElse(Option(0)), useSurroundingTypeFullName) + } else { + usedFullNames.add(candidate) + candidate + } + } + + override def column(node: RubyExpression): Option[Int] = node.column + override def columnEnd(node: RubyExpression): Option[Int] = node.columnEnd + override def line(node: RubyExpression): Option[Int] = node.line + override def lineEnd(node: RubyExpression): Option[Int] = node.lineEnd + + override def code(node: RubyExpression): String = shortenCode(node.text) protected def isBuiltin(x: String): Boolean = kernelFunctions.contains(x) - protected def prefixAsKernelDefined(x: String): String = s"$kernelPrefix$pathSep$x" - protected def prefixAsBundledType(x: String): String = s"${GlobalTypes.builtinPrefix}.$x" + protected def prefixAsKernelDefined(x: String): String = Defines.prefixAsKernelDefined(x) + protected def prefixAsCoreType(x: String): String = Defines.prefixAsCoreType(x) protected def isBundledClass(x: String): Boolean = GlobalTypes.bundledClasses.contains(x) protected def pathSep = "." - private def astForFieldInstance(name: String, node: RubyNode & RubyFieldIdentifier): Ast = { + private def astForFieldInstance(name: String, node: RubyExpression & RubyFieldIdentifier): Ast = { val identName = node match { case _: InstanceFieldIdentifier => Defines.Self case _: ClassFieldIdentifier => scope.surroundingTypeFullName.map(_.split("[.]").last).getOrElse(Defines.Any) @@ -47,7 +90,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As ) } - protected def handleVariableOccurrence(node: RubyNode): Ast = { + protected def handleVariableOccurrence(node: RubyExpression): Ast = { val name = code(node) val identifier = identifierNode(node, name, name, Defines.Any) val typeRef = scope.tryResolveTypeReference(name) @@ -92,12 +135,19 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As astForAssignment(Ast(lhs), Ast(rhs), lineNumber, columnNumber) } - protected def astForAssignment(lhs: Ast, rhs: Ast, lineNumber: Option[Int], columnNumber: Option[Int]): Ast = { - val code = Seq(lhs, rhs).flatMap(_.root).collect { case x: ExpressionNew => x.code }.mkString(" = ") + protected def astForAssignment( + lhs: Ast, + rhs: Ast, + lineNumber: Option[Int], + columnNumber: Option[Int], + code: Option[String] = None + ): Ast = { + val _code = + code.getOrElse(Seq(lhs, rhs).flatMap(_.root).collect { case x: ExpressionNew => x.code }.mkString(" = ")) val assignment = NewCall() .name(Operators.assignment) .methodFullName(Operators.assignment) - .code(code) + .code(_code) .dispatchType(DispatchTypes.STATIC_DISPATCH) .lineNumber(lineNumber) .columnNumber(columnNumber) @@ -116,6 +166,92 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As member } + /** Lowers the `||=` and `&&=` assignment operators to the respective conditional checks, e.g, `aaa ||= bbb` becomes + * `aaa = bbb if !aaa` `aaa &&= bbb` becomes aaa = bbb if aaa` + */ + def lowerAssignmentOperator(lhs: RubyExpression, rhs: RubyExpression, op: String, span: TextSpan): RubyExpression & + ControlFlowStatement = { + val condition = + if op == "||=" then UnaryExpression(op = "!", expression = lhs)(span.spanStart(s"!${lhs.span.text}")) + else lhs + val thenClause = StatementList( + List(SingleAssignment(lhs, "=", rhs)(span.spanStart(s"${lhs.span.text} = ${rhs.span.text}"))) + )(span.spanStart(s"${lhs.span.text} = ${rhs.span.text}")) + IfExpression(condition = condition, thenClause = thenClause, elsifClauses = List.empty, elseClause = None)( + span.spanStart(s"${thenClause.span.text} if ${condition.span.text}") + ) + } + + /** Regex matches implicitly assign values to global variables. Lowering the `=~` operator may look like + * + * { tmp = 'hello'.match(/h(el)lo/); if tmp; $~ = tmp; $& = tmp[0]; tmp.begin(0); else $~= nil; $& = nil; nil end; } + */ + def lowerRegexMatch(target: RubyExpression, regex: RubyExpression, originSpan: TextSpan): RubyExpression = { + // Create tmpName that takes the regex match result + val tmpName = this.tmpGen.fresh + val tmpGenLocal = NewLocal().name(tmpName).code(tmpName).typeFullName(Defines.Any) + scope.addToScope(tmpName, tmpGenLocal) match { + case BlockScope(block) => diffGraph.addEdge(block, tmpGenLocal, EdgeTypes.AST) + case _ => + } + def tmp = SimpleIdentifier()(originSpan.spanStart(tmpName)) + + val matchCall = { + val code = s"${regex.text}.match(${target.text})" + MemberCall(regex, ".", "match", target :: Nil)(originSpan.spanStart(code)) + } + val tmpAssignment = { + val code = s"$tmpName = ${matchCall.text}" + SingleAssignment(tmp, "=", matchCall)(originSpan.spanStart(code)) + } + + def self = SelfIdentifier()(originSpan.spanStart(Defines.Self)) + def globalTilde = MemberAccess(self, ".", "$~")(originSpan.spanStart("$~")) + def globalAmpersand = MemberAccess(self, ".", "$&")(originSpan.spanStart("$&")) + + val ifStmt = IfExpression( + condition = tmp, + thenClause = { + val tildeCode = s"$$~ = $tmpName" + val tildeAssign = SingleAssignment(globalTilde, "=", tmp)(originSpan.spanStart(tildeCode)) + + def intLiteral(n: Int) = StaticLiteral(Defines.prefixAsCoreType(Defines.Integer))(originSpan.spanStart(s"$n")) + val tmpIndex0 = IndexAccess(tmp, intLiteral(0) :: Nil)(originSpan.spanStart(s"$tmpName[0]")) + + val ampersandCode = s"$$& = $tmpName[0]" + val ampersandAssign = SingleAssignment(globalAmpersand, "=", tmpIndex0)(originSpan.spanStart(ampersandCode)) + + // use a simple heuristic to determine the N matched groups + val matchGroups = (1 to regex.text.count(_ == '(')).map { idx => + val matchGroupAsgnCode = s"$$$idx = $tmpName[$idx]" + val matchGroup = MemberAccess(self, ".", "$")(originSpan.spanStart("$")) + val matchGroupIndexN = IndexAccess(matchGroup, intLiteral(idx) :: Nil)(originSpan.spanStart(s"$$[$idx]")) + val tmpIndexN = IndexAccess(tmp, intLiteral(idx) :: Nil)(originSpan.spanStart(s"$tmpName[$idx]")) + SingleAssignment(matchGroupIndexN, "=", tmpIndexN)(originSpan.spanStart(matchGroupAsgnCode)) + }.toList + + // tmp.begin(0) is the lowered return value of `~=` + val beginCall = MemberCall(tmp, ".", "begin", intLiteral(0) :: Nil)(originSpan.spanStart(s"$tmpName.begin(0)")) + StatementList(tildeAssign :: ampersandAssign :: Nil ++ matchGroups :+ beginCall)( + originSpan.spanStart(s"$tildeCode; $ampersandCode") + ) + }, + elseClause = Option { + def nil = StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(originSpan.spanStart("nil")) + val tildeCode = s"$$~ = nil" + val tildeAssign = SingleAssignment(globalTilde, "=", nil)(originSpan.spanStart(tildeCode)) + + val ampersandCode = s"$$& = nil" + val ampersandAssign = SingleAssignment(globalAmpersand, "=", nil)(originSpan.spanStart(ampersandCode)) + + val elseSpan = originSpan.spanStart(s"$tildeCode; $ampersandCode; nil") + ElseClause(StatementList(tildeAssign :: ampersandAssign :: nil :: Nil)(elseSpan))(elseSpan) + } + )(originSpan.spanStart(s"if $tmpName ... else ... end")) + + StatementList(tmpAssignment :: ifStmt :: Nil)(originSpan) + } + protected val UnaryOperatorNames: Map[String, String] = Map( "!" -> Operators.logicalNot, "not" -> Operators.logicalNot, @@ -147,8 +283,8 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As "&" -> Operators.and, "|" -> Operators.or, "^" -> Operators.xor, - "<<" -> Operators.shiftLeft, - ">>" -> Operators.logicalShiftRight +// "<<" -> Operators.shiftLeft, Note: Generally Ruby abstracts this as an append operator based on the LHS + ">>" -> Operators.arithmeticShiftRight ) protected val AssignmentOperatorNames: Map[String, String] = Map( @@ -159,8 +295,9 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As "/=" -> Operators.assignmentDivision, "%=" -> Operators.assignmentModulo, "**=" -> Operators.assignmentExponentiation, - // Strictly speaking, `a ||= b` means `a || a = b`, but I reckon we wouldn't gain much representing it that way. - "||=" -> Operators.assignmentOr, - "&&=" -> Operators.assignmentAnd + "|=" -> Operators.assignmentOr, + "&=" -> Operators.assignmentAnd, + "<<=" -> Operators.assignmentShiftLeft, + ">>=" -> Operators.assignmentArithmeticShiftRight ) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala new file mode 100644 index 000000000000..05df8d93111c --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala @@ -0,0 +1,357 @@ +package io.joern.rubysrc2cpg.astcreation + +import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{ + ArrayLiteral, + ArrayPattern, + BinaryExpression, + BreakExpression, + CaseExpression, + ControlFlowStatement, + DoWhileExpression, + DummyAst, + ElseClause, + ForExpression, + IfExpression, + InClause, + IndexAccess, + MatchVariable, + MemberCall, + NextExpression, + OperatorAssignment, + RescueExpression, + RubyExpression, + SimpleIdentifier, + SingleAssignment, + SplattingRubyNode, + StatementList, + StaticLiteral, + UnaryExpression, + Unknown, + UnlessExpression, + UntilExpression, + WhenClause, + WhileExpression +} +import io.joern.rubysrc2cpg.passes.Defines +import io.joern.rubysrc2cpg.passes.Defines.RubyOperators +import io.joern.x2cpg.{Ast, ValidationMode} +import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewFieldIdentifier, NewLiteral, NewLocal} +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators} + +trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => + + protected def astForControlStructureExpression(node: ControlFlowStatement): Ast = node match { + case node: WhileExpression => astForWhileStatement(node) + case node: DoWhileExpression => astForDoWhileStatement(node) + case node: UntilExpression => astForUntilStatement(node) + case node: CaseExpression => blockAst(NewBlock(), astsForCaseExpression(node).toList) + case node: IfExpression => astForIfExpression(node) + case node: UnlessExpression => astForUnlessStatement(node) + case node: ForExpression => astForForExpression(node) + case node: RescueExpression => astForRescueExpression(node) + case node: NextExpression => astForNextExpression(node) + case node: BreakExpression => astForBreakExpression(node) + case node: OperatorAssignment => astForOperatorAssignmentExpression(node) + } + + private def astForWhileStatement(node: WhileExpression): Ast = { + val conditionAst = astForExpression(node.condition) + val bodyAsts = astsForStatement(node.body) + whileAst(Some(conditionAst), bodyAsts, Option(code(node)), line(node), column(node)) + } + + private def astForDoWhileStatement(node: DoWhileExpression): Ast = { + val conditionAst = astForExpression(node.condition) + val bodyAsts = astsForStatement(node.body) + doWhileAst(Some(conditionAst), bodyAsts, Option(code(node)), line(node), column(node)) + } + + // `until T do B` is lowered as `while !T do B` + private def astForUntilStatement(node: UntilExpression): Ast = { + val notCondition = astForExpression(UnaryExpression("!", node.condition)(node.condition.span)) + val bodyAsts = astsForStatement(node.body) + whileAst(Some(notCondition), bodyAsts, Option(code(node)), line(node), column(node)) + } + + // Recursively lowers into a ternary conditional call + private def astForIfExpression(node: IfExpression): Ast = { + def builder(node: IfExpression, conditionAst: Ast, thenAst: Ast, elseAsts: List[Ast]): Ast = { + // We want to make sure there's always an «else» clause in a ternary operator. + // The default value is a `nil` literal. + val elseAsts_ = if (elseAsts.isEmpty) { + List(astForNilBlock) + } else { + elseAsts + } + + val call = callNode(node, code(node), Operators.conditional, Operators.conditional, DispatchTypes.STATIC_DISPATCH) + callAst(call, conditionAst :: thenAst :: elseAsts_) + } + + // TODO: Remove or modify the builder pattern when we are no longer using ANTLR + node.elseClause match { + case Some(elseClause) => + elseClause match { + case _: IfExpression => astForJsonIfStatement(node) + case _ => foldIfExpression(builder)(node) + } + case None => + foldIfExpression(builder)(node) + } + } + + private def astForJsonIfStatement(node: IfExpression): Ast = { + val conditionAst = astForExpression(node.condition) + val thenAst = astForThenClause(node.thenClause) + val elseAsts = node.elseClause + .map { + case x: IfExpression => + val wrappedBlock = blockNode(x) + Ast(wrappedBlock).withChild(astForJsonIfStatement(x)) + case x => + astForElseClause(x) + } + .getOrElse(Ast()) + + val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code(node)) + controlStructureAst(ifNode, Some(conditionAst), thenAst :: elseAsts :: Nil) + } + + // `unless T do B` is lowered as `if !T then B` + private def astForUnlessStatement(node: UnlessExpression): Ast = { + val notConditionAst = astForExpression(UnaryExpression("!", node.condition)(node.condition.span)) + val thenAst = node.trueBranch match + case stmtList: StatementList => astForStatementList(stmtList) + case _ => astForStatementList(StatementList(List(node.trueBranch))(node.trueBranch.span)) + val elseAsts = node.falseBranch.map(astForElseClause).toList + val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code(node)) + controlStructureAst(ifNode, Some(notConditionAst), thenAst :: elseAsts) + } + + protected def astForElseClause(node: RubyExpression): Ast = { + node match + case elseNode: ElseClause => + elseNode.thenClause match + case stmtList: StatementList => astForStatementList(stmtList) + case node => + logger.warn(s"Expecting statement list in ${code(node)} ($relativeFileName), skipping") + astForUnknown(node) + case elseNode => + logger.warn(s"Expecting else clause in ${code(elseNode)} ($relativeFileName), skipping") + astForUnknown(elseNode) + } + + private def astForForExpression(node: ForExpression): Ast = { + val forEachNode = controlStructureNode(node, ControlStructureTypes.FOR, code(node)) + + def collectionAst = astForExpression(node.iterableVariable) + val collectionNode = node.iterableVariable + + val iterIdentifier = + identifierNode( + node = node.forVariable, + name = node.forVariable.span.text, + code = node.forVariable.span.text, + typeFullName = Defines.Any + ) + val iterVarLocal = NewLocal().name(node.forVariable.span.text).code(node.forVariable.span.text) + scope.addToScope(node.forVariable.span.text, iterVarLocal) + + val idxName = "_idx_" + val idxLocal = NewLocal().name(idxName).code(idxName).typeFullName(Defines.prefixAsCoreType(Defines.Integer)) + val idxIdenAtAssign = identifierNode( + node = collectionNode, + name = idxName, + code = idxName, + typeFullName = Defines.prefixAsCoreType(Defines.Integer) + ) + + val idxAssignment = + callNode(node, s"$idxName = 0", Operators.assignment, Operators.assignment, DispatchTypes.STATIC_DISPATCH) + val idxAssignmentArgs = + List(Ast(idxIdenAtAssign), Ast(NewLiteral().code("0").typeFullName(Defines.prefixAsCoreType(Defines.Integer)))) + val idxAssignmentAst = callAst(idxAssignment, idxAssignmentArgs) + + val idxIdAtCond = idxIdenAtAssign.copy + val collectionCountAccess = callNode( + node, + s"${node.iterableVariable.span.text}.length", + Operators.fieldAccess, + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH + ) + val fieldAccessAst = callAst( + collectionCountAccess, + collectionAst :: Ast(NewFieldIdentifier().canonicalName("length").code("length")) :: Nil + ) + + val idxLt = callNode( + node, + s"$idxName < ${node.iterableVariable.span.text}.length", + Operators.lessThan, + Operators.lessThan, + DispatchTypes.STATIC_DISPATCH + ) + val idxLtArgs = List(Ast(idxIdAtCond), fieldAccessAst) + val ltCallCond = callAst(idxLt, idxLtArgs) + + val idxIdAtCollAccess = idxIdenAtAssign.copy + val collectionIdxAccess = callNode( + node, + s"${node.iterableVariable.span.text}[$idxName++]", + Operators.indexAccess, + Operators.indexAccess, + DispatchTypes.STATIC_DISPATCH + ) + val postIncrAst = callAst( + callNode(node, s"$idxName++", Operators.postIncrement, Operators.postIncrement, DispatchTypes.STATIC_DISPATCH), + Ast(idxIdAtCollAccess) :: Nil + ) + + val indexAccessAst = callAst(collectionIdxAccess, collectionAst :: postIncrAst :: Nil) + val iteratorAssignmentNode = callNode( + node, + s"${node.forVariable.span.text} = ${node.iterableVariable.span.text}[$idxName++]", + Operators.assignment, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val iteratorAssignmentArgs = List(Ast(iterIdentifier), indexAccessAst) + val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) + val doBodyAst = astsForStatement(node.doBlock) + + val locals = Ast(idxLocal) + .withRefEdge(idxIdenAtAssign, idxLocal) + .withRefEdge(idxIdAtCond, idxLocal) + .withRefEdge(idxIdAtCollAccess, idxLocal) :: Ast(iterVarLocal).withRefEdge(iterIdentifier, iterVarLocal) :: Nil + + val conditionAsts = ltCallCond :: Nil + val initAsts = idxAssignmentAst :: Nil + val updateAsts = iteratorAssignmentAst :: Nil + + forAst( + forNode = forEachNode, + locals = locals, + initAsts = initAsts, + conditionAsts = conditionAsts, + updateAsts = updateAsts, + bodyAsts = doBodyAst + ) + } + + protected def astsForCaseExpression(node: CaseExpression): Seq[Ast] = { + // TODO: Clean up the below + def goCase(expr: Option[SimpleIdentifier]): List[RubyExpression] = { + val elseThenClause: Option[RubyExpression] = node.elseClause.map(_.asInstanceOf[ElseClause].thenClause) + val whenClauses = node.matchClauses.collect { case x: WhenClause => x } + val inClauses = node.matchClauses.collect { case x: InClause => x } + + val ifElseChain = if (whenClauses.nonEmpty) { + whenClauses.foldRight[Option[RubyExpression]](elseThenClause) { + (whenClause: WhenClause, restClause: Option[RubyExpression]) => + // We translate multiple match expressions into an or expression. + // + // A single match expression is compared using `.===` to the case target expression if it is present + // otherwise it is treated as a conditional. + // + // There may be a splat as the last match expression, + // `case y when *x then c end` or + // `case when *x then c end` + // which is translated to `x.include? y` and `x.any?` conditions respectively + + val conditions = whenClause.matchExpressions.map { + case regex: StaticLiteral if regex.typeFullName == prefixAsCoreType(Defines.Regexp) => + expr.map(e => BinaryExpression(regex, RubyOperators.regexpMatch, e)(regex.span)).getOrElse(regex) + case mExpr => + expr.map(e => BinaryExpression(mExpr, "===", e)(mExpr.span)).getOrElse(mExpr) + } ++ whenClause.matchSplatExpression.iterator.flatMap { + case splat @ SplattingRubyNode(exprList) => + expr + .map { e => + List(MemberCall(exprList, ".", "include?", List(e))(splat.span)) + } + .getOrElse { + List(MemberCall(exprList, ".", "any?", List())(splat.span)) + } + case e => + logger.warn(s"Unrecognised RubyNode (${e.getClass}) in case match splat expression") + List(Unknown()(e.span)) + } + // There is always at least one match expression or a splat + // will become an unknown in condition at the end + val condition = conditions.init.foldRight(conditions.last) { (cond, condAcc) => + BinaryExpression(cond, "||", condAcc)(whenClause.span) + } + val conditional = IfExpression( + condition, + whenClause.thenClause.asStatementList, + List(), + restClause.map { els => ElseClause(els.asStatementList)(els.span) } + )(node.span) + Some(conditional) + } + } else { + inClauses.foldRight[Option[RubyExpression]](elseThenClause) { + (inClause: InClause, restClause: Option[RubyExpression]) => + val (condition, body) = inClause.pattern match { + case x: ArrayPattern => + val condition = expr.map(e => BinaryExpression(x, "===", e)(x.span)).getOrElse(inClause.pattern) + val body = inClause.body + + val stmts = x.children.zipWithIndex.flatMap { + case (lhs: MatchVariable, idx) if expr.isDefined => + val arrAccess = { + val code = s"${expr.get.text}[$idx]" + val base = expr.get.copy()(expr.get.span.spanStart(expr.get.text)) + val indices = StaticLiteral(Defines.prefixAsCoreType(Defines.Integer))( + expr.get.span.spanStart(idx.toString) + ) :: Nil + IndexAccess(base, indices)(lhs.span.spanStart(code)) + } + val asgn = SingleAssignment(lhs, "=", arrAccess)( + inClause.span.spanStart(s"${lhs.span.text} = ${expr.get.text}[$idx]") + ) + Option(asgn) + case _ => None + } :+ body + val conditionBody = StatementList(stmts)(body.span) + + (condition, conditionBody) + case x => + (x, inClause.body) + } + + val conditional = IfExpression( + condition, + body, + List.empty, + restClause.map { els => ElseClause(els.asStatementList)(els.span) } + )(node.span) + Some(conditional) + } + } + ifElseChain.iterator.toList + } + + val caseExpr = node.expression + .map { + case arrayLiteral: ArrayLiteral => + val tmp = SimpleIdentifier(None)(arrayLiteral.span.spanStart(this.tmpGen.fresh)) + val arrayLiteralAst = DummyAst(astForArrayLiteral(arrayLiteral))(arrayLiteral.span) + (tmp, arrayLiteralAst) + case e => + val tmp = SimpleIdentifier(None)(e.span.spanStart(this.tmpGen.fresh)) + (tmp, e) + } + .map((tmp, e) => StatementList(List(SingleAssignment(tmp, "=", e)(e.span)) ++ goCase(Some(tmp)))(node.span)) + .getOrElse(StatementList(goCase(None))(node.span)) + + astsForStatement(caseExpr) + } + + private def astForOperatorAssignmentExpression(node: OperatorAssignment): Ast = { + val loweredAssignment = lowerAssignmentOperator(node.lhs, node.rhs, node.op, node.span) + astForControlStructureExpression(loweredAssignment) + } + +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index 8c1000f388f3..cfd199305cb2 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -2,10 +2,10 @@ package io.joern.rubysrc2cpg.astcreation import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{Unknown, Block as RubyBlock, *} import io.joern.rubysrc2cpg.datastructures.BlockScope +import io.joern.rubysrc2cpg.parser.RubyJsonHelpers import io.joern.rubysrc2cpg.passes.Defines import io.joern.rubysrc2cpg.passes.GlobalTypes -import io.joern.rubysrc2cpg.passes.Defines.{RubyOperators, getBuiltInType} -import io.joern.rubysrc2cpg.utils.FreshNameGenerator +import io.joern.rubysrc2cpg.passes.Defines.{RubyOperators, prefixAsKernelDefined} import io.joern.x2cpg.{Ast, ValidationMode, Defines as XDefines} import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{ @@ -17,46 +17,54 @@ import io.shiftleft.codepropertygraph.generated.{ PropertyNames } -trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - val tmpGen: FreshNameGenerator[String] = FreshNameGenerator(i => s"") - - protected def astForExpression(node: RubyNode): Ast = node match - case node: StaticLiteral => astForStaticLiteral(node) - case node: HereDocNode => astForHereDoc(node) - case node: DynamicLiteral => astForDynamicLiteral(node) - case node: UnaryExpression => astForUnary(node) - case node: BinaryExpression => astForBinary(node) - case node: MemberAccess => astForMemberAccess(node) - case node: MemberCall => astForMemberCall(node) - case node: ObjectInstantiation => astForObjectInstantiation(node) - case node: IndexAccess => astForIndexAccess(node) - case node: SingleAssignment => astForSingleAssignment(node) - case node: AttributeAssignment => astForAttributeAssignment(node) - case node: TypeIdentifier => astForTypeIdentifier(node) - case node: RubyIdentifier => astForSimpleIdentifier(node) - case node: SimpleCall => astForSimpleCall(node) - case node: RequireCall => astForRequireCall(node) - case node: IncludeCall => astForIncludeCall(node) - case node: YieldExpr => astForYield(node) - case node: RangeExpression => astForRange(node) - case node: ArrayLiteral => astForArrayLiteral(node) - case node: HashLiteral => astForHashLiteral(node) - case node: Association => astForAssociation(node) - case node: IfExpression => astForIfExpression(node) - case node: UnlessExpression => astForUnlessExpression(node) - case node: RescueExpression => astForRescueExpression(node) - case node: CaseExpression => blockAst(NewBlock(), astsForCaseExpression(node).toList) - case node: MandatoryParameter => astForMandatoryParameter(node) - case node: SplattingRubyNode => astForSplattingRubyNode(node) - case node: AnonymousTypeDeclaration => astForAnonymousTypeDeclaration(node) - case node: ProcOrLambdaExpr => astForProcOrLambdaExpr(node) - case node: RubyCallWithBlock[_] => astForCallWithBlock(node) - case node: SelfIdentifier => astForSelfIdentifier(node) - case node: BreakStatement => astForBreakStatement(node) - case node: StatementList => astForStatementList(node) - case node: DummyNode => Ast(node.node) - case node: Unknown => astForUnknown(node) +import scala.collection.mutable + +trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { + this: AstCreator => + + /** For tracking aliased calls that occur on the LHS of a member access or call. + */ + protected val baseAstCache = mutable.Map.empty[RubyExpression, String] + + protected def astForExpression(node: RubyExpression): Ast = node match + case node: ControlFlowStatement => astForControlStructureExpression(node) + case node: StaticLiteral => astForStaticLiteral(node) + case node: HereDocNode => astForHereDoc(node) + case node: DynamicLiteral => astForDynamicLiteral(node) + case node: UnaryExpression => astForUnary(node) + case node: BinaryExpression => astForBinary(node) + case node: MemberAccess => astForMemberAccess(node) + case node: MemberCall => astForMemberCall(node) + case node: ObjectInstantiation => astForObjectInstantiation(node) + case node: IndexAccess => astForIndexAccess(node) + case node: SingleAssignment => astForSingleAssignment(node) + case node: AttributeAssignment => astForAttributeAssignment(node) + case node: TypeIdentifier => astForTypeIdentifier(node) + case node: RubyIdentifier => astForSimpleIdentifier(node) + case node: SimpleCall => astForSimpleCall(node) + case node: RequireCall => astForRequireCall(node) + case node: IncludeCall => astForIncludeCall(node) + case node: RaiseCall => astForRaiseCall(node) + case node: YieldExpr => astForYield(node) + case node: RangeExpression => astForRange(node) + case node: ArrayLiteral => astForArrayLiteral(node) + case node: HashLike => astForHashLiteral(node) + case node: Association => astForAssociation(node) + case node: MandatoryParameter => astForMandatoryParameter(node) + case node: SplattingRubyNode => astForSplattingRubyNode(node) + case node: AnonymousTypeDeclaration => astForAnonymousTypeDeclaration(node) + case node: ProcOrLambdaExpr => astForProcOrLambdaExpr(node) + case node: SingletonObjectMethodDeclaration => astForSingletonObjectMethodDeclaration(node) + case node: RubyCallWithBlock[_] => astForCallWithBlock(node) + case node: SelfIdentifier => astForSelfIdentifier(node) + case node: StatementList => astForStatementList(node) + case node: MultipleAssignment => blockAst(blockNode(node), astsForStatement(node).toList) + case node: ReturnExpression => astForReturnExpression(node) + case node: AccessModifier => astForSimpleIdentifier(node.toSimpleIdentifier) + case node: ArrayPattern => astForArrayPattern(node) + case node: DummyNode => Ast(node.node) + case node: DummyAst => node.ast + case node: Unknown => astForUnknown(node) case x => logger.warn(s"Unhandled expression of type ${x.getClass.getSimpleName}") astForUnknown(node) @@ -66,12 +74,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } protected def astForHereDoc(node: HereDocNode): Ast = { - Ast(literalNode(node, code(node), getBuiltInType("String"))) + Ast(literalNode(node, code(node), prefixAsKernelDefined("String"))) } // Helper for nil literals to put in empty clauses - protected def astForNilLiteral: Ast = Ast(NewLiteral().code("nil").typeFullName(getBuiltInType(Defines.NilClass))) - protected def astForNilBlock: Ast = blockAst(NewBlock(), List(astForNilLiteral)) + protected def astForNilLiteral: Ast = Ast( + NewLiteral().code("nil").typeFullName(prefixAsKernelDefined(Defines.NilClass)) + ) + + protected def astForNilBlock: Ast = blockAst(NewBlock(), List(astForNilLiteral)) protected def astForDynamicLiteral(node: DynamicLiteral): Ast = { val fmtValueAsts = node.expressions.map { @@ -93,8 +104,16 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { ) astForUnknown(stmtList) case node => - logger.warn(s"Unsupported interpolated literal content: ${code(node)} ($relativeFileName), skipping") - astForUnknown(node) + val call = callNode( + node = node, + code = node.text, + name = Operators.formattedValue, + methodFullName = Operators.formattedValue, + dispatchType = DispatchTypes.STATIC_DISPATCH, + signature = None, + typeFullName = Option(Defines.Any) + ) + callAst(call, Seq(astForExpression(node))) } callAst( callNode( @@ -151,19 +170,19 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { /** Attempts to extract a type from the base of a member call. */ - protected def typeFromCallTarget(baseNode: RubyNode): Option[String] = { - scope.lookupVariable(baseNode.text) match { - // fixme: This should be under type recovery logic - case Some(decl: NewLocal) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) - case Some(decl: NewMethodParameterIn) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) - case Some(decl: NewLocal) if decl.dynamicTypeHintFullName.nonEmpty => decl.dynamicTypeHintFullName.headOption - case Some(decl: NewMethodParameterIn) if decl.dynamicTypeHintFullName.nonEmpty => - decl.dynamicTypeHintFullName.headOption + protected def typeFromCallTarget(baseNode: RubyExpression): Option[String] = { + baseNode match { + case literal: LiteralExpr => Option(literal.typeFullName) case _ => - astForExpression(baseNode).nodes - .flatMap(_.properties.get(PropertyNames.TYPE_FULL_NAME).map(_.toString)) - .filterNot(_ == XDefines.Any) - .headOption + scope.lookupVariable(baseNode.text) match { + // fixme: This should be under type recovery logic + case Some(decl: NewLocal) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) + case Some(decl: NewMethodParameterIn) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) + case Some(decl: NewLocal) if decl.dynamicTypeHintFullName.nonEmpty => decl.dynamicTypeHintFullName.headOption + case Some(decl: NewMethodParameterIn) if decl.dynamicTypeHintFullName.nonEmpty => + decl.dynamicTypeHintFullName.headOption + case _ => None + } } } @@ -171,38 +190,54 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { Ast(typeRefNode(node, code(node), node.typeFullName)) } - protected def astForMemberCall(node: MemberCall): Ast = { + protected def astForMemberCall(node: MemberCall, isStatic: Boolean = false): Ast = { def createMemberCall(n: MemberCall): Ast = { - val baseAst = astForExpression(n.target) // this wil be something like self.Foo - val receiverAst = astForExpression(MemberAccess(n.target, ".", n.methodName)(n.span)) + val receiverAst = astForFieldAccess(MemberAccess(n.target, ".", n.methodName)(n.span), stripLeadingAt = true) + val (baseAst, baseCode) = astForMemberAccessTarget(n.target) val builtinType = n.target match { case MemberAccess(_: SelfIdentifier, _, memberName) if isBundledClass(memberName) => - Option(prefixAsBundledType(memberName)) + Option(prefixAsCoreType(memberName)) case x: TypeIdentifier if x.isBuiltin => Option(x.typeFullName) case _ => None } - val (receiverFullName, methodFullName) = receiverAst.nodes + val methodFullName = receiverAst.nodes .collectFirst { - case _ if builtinType.isDefined => builtinType.get -> s"${builtinType.get}:${n.methodName}" - case x: NewMethodRef => x.methodFullName -> x.methodFullName + case _ if builtinType.isDefined => s"${builtinType.get}.${n.methodName}" + case x: NewMethodRef => x.methodFullName case _ => (n.target match { case ma: MemberAccess => scope.tryResolveTypeReference(ma.memberName).map(_.name) case _ => typeFromCallTarget(n.target) - }).map(x => x -> s"$x:${n.methodName}") - .getOrElse(XDefines.Any -> XDefines.DynamicCallUnknownFullName) + }).map(x => s"$x.${n.methodName}") + .getOrElse(XDefines.DynamicCallUnknownFullName) } - .getOrElse(XDefines.Any -> XDefines.DynamicCallUnknownFullName) + .getOrElse(XDefines.DynamicCallUnknownFullName) val argumentAsts = n.arguments.map(astForMethodCallArgument) - val dispatchType = DispatchTypes.DYNAMIC_DISPATCH + val dispatchType = if (isStatic) DispatchTypes.STATIC_DISPATCH else DispatchTypes.DYNAMIC_DISPATCH - val call = callNode(n, code(n), n.methodName, XDefines.DynamicCallUnknownFullName, dispatchType) + val callCode = if (baseCode.contains(" target case x: SimpleIdentifier => scope.getSurroundingType(x.text).map(_.fullName) match { @@ -210,20 +245,114 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val typeName = surroundingType.split('.').last TypeIdentifier(s"$surroundingType")(x.span.spanStart(typeName)) case None if scope.lookupVariable(x.text).isDefined => x - case None => MemberAccess(SelfIdentifier()(x.span.spanStart(Defines.Self)), ".", x.text)(x.span) + case None if x.text.charAt(0).isUpper => // calls have lower-case first character + MemberAccess(SelfIdentifier()(x.span.spanStart(Defines.Self)), ".", x.text)(x.span) + case None => MemberCall(SelfIdentifier()(x.span.spanStart(Defines.Self)), ".", x.text, Nil)(x.span) } - case x @ MemberAccess(ma, op, memberName) => x.copy(target = determineMemberAccessBase(ma))(x.span) - case _ => target + case x @ MemberAccess(ma, _, _) => x.copy(target = determineMemberAccessBase(ma))(x.span) + case _ => target } node.target match { + case regex @ StaticLiteral(Defines.Regexp) if node.isRegexMatch => + val loweredRegex = node.arguments.headOption match { + case Some(literal) => lowerRegexMatch(literal, regex, node.span) + case None => + val self = SelfIdentifier()(node.span.spanStart(Defines.Self)) + val globalDefaultString = MemberAccess(self, ".", "$_")(node.span.spanStart("$_")) + lowerRegexMatch(globalDefaultString, regex, node.span) + } + astForExpression(loweredRegex) + // Regex on the RHS is more idiomatic, so no need to check types here. + case literal: LiteralExpr if node.isRegexMatch => + node.arguments.headOption match { + case Some(regex) => astForExpression(lowerRegexMatch(literal, regex, node.span)) + case None => + logger.warn("Regex match with empty argument, defaulting to ordinary member call") + createMemberCall(node) + } + case _: LiteralExpr => + createMemberCall(node) case x: SimpleIdentifier if isBundledClass(x.text) => - createMemberCall(node.copy(target = TypeIdentifier(prefixAsBundledType(x.text))(x.span))(node.span)) + createMemberCall(node.copy(target = TypeIdentifier(prefixAsCoreType(x.text))(x.span))(node.span)) case x: SimpleIdentifier => createMemberCall(node.copy(target = determineMemberAccessBase(x))(node.span)) case memAccess: MemberAccess => createMemberCall(node.copy(target = determineMemberAccessBase(memAccess))(node.span)) - case x => createMemberCall(node) + case _ => createMemberCall(node) + } + } + + protected def astForFieldAccess(node: MemberAccess, stripLeadingAt: Boolean = false): Ast = { + val (memberName, memberCode) = node.target match { + case _ if node.memberName == Defines.Initialize => Defines.Initialize -> Defines.Initialize + case _ if stripLeadingAt => node.memberName -> node.memberName.stripPrefix("@") + case _: TypeIdentifier => node.memberName -> node.memberName + case _ if !node.memberName.startsWith("@") && node.memberName.headOption.exists(_.isLower) => + s"@${node.memberName}" -> node.memberName + case _ => node.memberName -> node.memberName + } + + val fieldIdentifierAst = Ast(fieldIdentifierNode(node, memberName, memberCode)) + val (targetAst, _code) = astForMemberAccessTarget(node.target) + val code = s"$_code${node.op}$memberCode" + val memberType = typeFromCallTarget(node.target) + .flatMap(scope.tryResolveTypeReference) + .map(_.fields) + .getOrElse(List.empty) + .collectFirst { + case x if x.name == memberName => + scope.tryResolveTypeReference(x.typeName).map(_.name).getOrElse(Defines.Any) + } + .orElse(Option(Defines.Any)) + val fieldAccess = callNode( + node, + code, + Operators.fieldAccess, + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH, + signature = None, + typeFullName = Option(Defines.Any) + ).possibleTypes(IndexedSeq(memberType.get)) + callAst(fieldAccess, Seq(targetAst, fieldIdentifierAst)) + } + + private def astForMemberAccessTarget(target: RubyExpression): (Ast, String) = { + target match { + case simpleLhs: (LiteralExpr | SimpleIdentifier | SelfIdentifier | TypeIdentifier) => + astForExpression(simpleLhs) -> code(target) + case target: MemberAccess => handleTmpGen(target, astForFieldAccess(target, stripLeadingAt = true)) + case target => handleTmpGen(target, astForExpression(target)) + } + } + + private def handleTmpGen(target: RubyExpression, rhs: Ast): (Ast, String) = { + // Check cache + val createAssignmentToTmp = !baseAstCache.contains(target) + val tmpName = baseAstCache + .updateWith(target) { + case Some(tmpName) => + // TODO: Type ref nodes are automatically committed on creation, so if we have found a suitable cached AST, + // we want to clean this creation up. + Option(tmpName) + case None => + val tmpName = this.tmpGen.fresh + val tmpGenLocal = NewLocal().name(tmpName).code(tmpName).typeFullName(Defines.Any) + scope.addToScope(tmpName, tmpGenLocal) match { + case BlockScope(block) => diffGraph.addEdge(block, tmpGenLocal, EdgeTypes.AST) + case _ => + } + Option(tmpName) + } + .get + val tmpIden = NewIdentifier().name(tmpName).code(tmpName).typeFullName(Defines.Any) + val tmpIdenAst = + scope.lookupVariable(tmpName).map(x => Ast(tmpIden).withRefEdge(tmpIden, x)).getOrElse(Ast(tmpIden)) + val code = s"$tmpName = ${target.text}" + if (createAssignmentToTmp) { + astForAssignment(tmpIdenAst, rhs, target.line, target.column, Option(code)) -> s"($code)" + } else { + tmpIdenAst -> s"($code)" } } @@ -249,33 +378,60 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { expr } .getOrElse(defaultBehaviour) + case None if node.indices.isEmpty => + astForExpression(MemberCall(node.target, ".", "[]", node.indices)(node.span)) case None => defaultBehaviour } } - protected def astForObjectInstantiation(node: RubyNode & ObjectInstantiation): Ast = { - val className = node.target.text - val callName = "new" - val methodName = Defines.Initialize + /* `foo() do end` is lowered as a METHOD node shaped like so: + * ``` + * = def 0() + * + * end + * foo(, ) + * ``` + */ + protected def astForCallWithBlock[C <: RubyCall](node: RubyExpression & RubyCallWithBlock[C]): Ast = { + val Seq(typeRef, _) = astForDoBlock(node.block): @unchecked + val typeRefDummyNode = typeRef.root.map(DummyNode(_)(node.span)).toList + + // Create call with argument referencing the MethodRef + val callWithLambdaArg = node.withoutBlock match { + case x: SimpleCall => astForSimpleCall(x.copy(arguments = x.arguments ++ typeRefDummyNode)(x.span)) + case x: MemberCall => astForMemberCall(x.copy(arguments = x.arguments ++ typeRefDummyNode)(x.span)) + case x => + logger.warn(s"Unhandled call-with-block type ${code(x)}, creating anonymous method structures only") + Ast() + } + + callWithLambdaArg + } + + protected def astForObjectInstantiation(node: RubyExpression & ObjectInstantiation): Ast = { /* We short-cut the call edge from `new` call to `initialize` method, however we keep the modelling of the receiver as referring to the singleton class. */ - val (receiverTypeFullName, fullName) = scope.tryResolveTypeReference(className) match { - case Some(typeMetaData) => s"${typeMetaData.name}" -> s"${typeMetaData.name}:$methodName" - case None => XDefines.Any -> XDefines.DynamicCallUnknownFullName + val (receiverTypeFullName, fullName) = node.target match { + case x: (SimpleIdentifier | MemberAccess) => + scope.tryResolveTypeReference(x.text) match { + case Some(typeMetaData) => s"${typeMetaData.name}" -> s"${typeMetaData.name}.${Defines.Initialize}" + case None => XDefines.Any -> XDefines.DynamicCallUnknownFullName + } + case _ => XDefines.Any -> XDefines.DynamicCallUnknownFullName } /* Similarly to some other frontends, we lower the constructor into two operations, e.g., `return Bar.new`, lowered to `return {Bar tmp = Bar.(); tmp.(); tmp}` */ - val block = blockNode(node) + val block = blockNode(node, node.text, Defines.Any) scope.pushNewScope(BlockScope(block)) - val tmpName = tmpGen.fresh + val tmpName = this.tmpGen.fresh val tmpTypeHint = receiverTypeFullName.stripSuffix("") - val tmp = SimpleIdentifier(Option(className))(node.span.spanStart(tmpName)) + val tmp = SimpleIdentifier(None)(node.span.spanStart(tmpName)) val tmpLocal = NewLocal().name(tmpName).code(tmpName).dynamicTypeHintFullName(Seq(tmpTypeHint)) scope.addToScope(tmpName, tmpLocal) @@ -286,12 +442,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } // Assign tmp to - val receiverAst = Ast(identifierNode(node, className, className, receiverTypeFullName)) - val allocCall = callNode(node, code(node), Operators.alloc, Operators.alloc, DispatchTypes.STATIC_DISPATCH) - val allocAst = callAst(allocCall, Seq.empty, Option(receiverAst)) + val allocCall = callNode(node, code(node), Operators.alloc, Operators.alloc, DispatchTypes.STATIC_DISPATCH) + val allocAst = callAst(allocCall, Seq.empty) val assignmentCall = callNode( node, - s"${tmp.text} = ${code(node)}", + s"${tmp.text} = ${code(node.target)}.${Defines.Initialize}", Operators.assignment, Operators.assignment, DispatchTypes.STATIC_DISPATCH @@ -302,12 +457,21 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val argumentAsts = node match { case x: SimpleObjectInstantiation => x.arguments.map(astForMethodCallArgument) case x: ObjectInstantiationWithBlock => - val Seq(_, methodRef) = astForDoBlock(x.block): @unchecked - x.arguments.map(astForMethodCallArgument) :+ methodRef + val Seq(typeRef, _) = astForDoBlock(x.block): @unchecked + x.arguments.map(astForMethodCallArgument) :+ typeRef } - val constructorCall = callNode(node, code(node), callName, fullName, DispatchTypes.DYNAMIC_DISPATCH) - val constructorCallAst = callAst(constructorCall, argumentAsts, Option(tmpIdentifier)) + val constructorCall = + callNode( + node, + code(node), + Defines.Initialize, + XDefines.DynamicCallUnknownFullName, + DispatchTypes.DYNAMIC_DISPATCH + ) + if fullName != XDefines.DynamicCallUnknownFullName then constructorCall.dynamicTypeHintFullName(Seq(fullName)) + val constructorRecv = astForExpression(MemberAccess(node.target, ".", Defines.Initialize)(node.span)) + val constructorCallAst = callAst(constructorCall, argumentAsts, Option(tmpIdentifier), Option(constructorRecv)) val retIdentifierAst = tmpIdentifier scope.popScope() @@ -327,27 +491,36 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { astForUnknown(node) case Some(op) => node.rhs match { - case cfNode: ControlFlowExpression => + case cfNode: ControlFlowStatement => def elseAssignNil(span: TextSpan) = Option { ElseClause( StatementList( SingleAssignment( node.lhs, node.op, - StaticLiteral(getBuiltInType(Defines.NilClass))(span.spanStart("nil")) + StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(span.spanStart("nil")) )(span.spanStart(s"${node.lhs.span.text} ${node.op} nil")) :: Nil )(span.spanStart(s"${node.lhs.span.text} ${node.op} nil")) )(span.spanStart(s"else\n\t${node.lhs.span.text} ${node.op} nil\nend")) } - def transform(e: RubyNode & ControlFlowExpression): RubyNode = + def transform(e: RubyExpression & ControlFlowStatement): RubyExpression = transformLastRubyNodeInControlFlowExpressionBody( e, x => reassign(node.lhs, node.op, x, transform), elseAssignNil ) - astForExpression(transform(cfNode)) + + cfNode match { + case x @ OperatorAssignment(lhs, op, rhs) => + val loweredNode = lowerAssignmentOperator(lhs, rhs, op, x.span) + astForExpression(transform(loweredNode)) + case x => + astForExpression(transform(cfNode)) + } + case _ => + val rhsAst = astForExpression(node.rhs) // The if the LHS defines a new variable, put the local variable into scope val lhsAst = node.lhs match { case x: SimpleIdentifier if scope.lookupVariable(code(x)).isEmpty => @@ -358,9 +531,25 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case _ => } astForExpression(node.lhs) + case SplattingRubyNode(nameNode: SimpleIdentifier) if scope.lookupVariable(code(nameNode)).isEmpty => + val name = code(nameNode) + val local = localNode(nameNode, name, name, Defines.Any) + scope.addToScope(name, local) match { + case BlockScope(block) => diffGraph.addEdge(block, local, EdgeTypes.AST) + case _ => + } + astForExpression(node.lhs) + case x: GroupedParameter => + val asts = astsForStatement(x.multipleAssignment) + val call = callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH) + return callAst(call, asts :+ rhsAst) + case x: MatchVariable => + handleVariableOccurrence(x.toSimpleIdentifier) // Create local variable under this scope + val matchIden = astForExpression(x.toSimpleIdentifier) + val call = callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH) + return callAst(call, matchIden :: rhsAst :: Nil) case _ => astForExpression(node.lhs) } - val rhsAst = astForExpression(node.rhs) // If this is a simple object instantiation assignment, we can give the LHS variable a type hint if (node.rhs.isInstanceOf[ObjectInstantiation] && lhsAst.root.exists(_.isInstanceOf[NewIdentifier])) { @@ -386,51 +575,60 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } private def reassign( - lhs: RubyNode, + lhs: RubyExpression, op: String, - rhs: RubyNode, - transform: (RubyNode & ControlFlowExpression) => RubyNode - ): RubyNode = { - def stmtListAssigningLastExpression(stmts: List[RubyNode]): List[RubyNode] = stmts match { - case (head: ControlFlowClause) :: Nil => clauseAssigningLastExpression(head) :: Nil - case (head: ControlFlowExpression) :: Nil => transform(head) :: Nil + rhs: RubyExpression, + transform: (RubyExpression & ControlFlowStatement) => RubyExpression + ): RubyExpression = { + def stmtListAssigningLastExpression(stmts: List[RubyExpression]): List[RubyExpression] = stmts match { + case (head: ControlFlowClause) :: Nil => clauseAssigningLastExpression(head) :: Nil + case (head: ControlFlowStatement) :: Nil => transform(head) :: Nil case head :: Nil => SingleAssignment(lhs, op, head)(rhs.span.spanStart(s"${lhs.span.text} $op ${head.span.text}")) :: Nil case Nil => List.empty case head :: tail => head :: stmtListAssigningLastExpression(tail) } - def clauseAssigningLastExpression(x: RubyNode & ControlFlowClause): RubyNode = x match { + def clauseAssigningLastExpression(x: RubyExpression & ControlFlowClause): RubyExpression = x match { case RescueClause(exceptionClassList, assignment, thenClause) => RescueClause(exceptionClassList, assignment, reassign(lhs, op, thenClause, transform))(x.span) case EnsureClause(thenClause) => EnsureClause(reassign(lhs, op, thenClause, transform))(x.span) case ElsIfClause(condition, thenClause) => ElsIfClause(condition, reassign(lhs, op, thenClause, transform))(x.span) - case ElseClause(thenClause) => ElseClause(reassign(lhs, op, thenClause, transform))(x.span) + case ElseClause(thenClause) => ElseClause(reassign(lhs, op, thenClause, transform))(x.span) + case InClause(pattern, body) => InClause(pattern, reassign(lhs, op, body, transform))(x.span) case WhenClause(matchExpressions, matchSplatExpression, thenClause) => WhenClause(matchExpressions, matchSplatExpression, reassign(lhs, op, thenClause, transform))(x.span) } rhs match { - case StatementList(statements) => StatementList(stmtListAssigningLastExpression(statements))(rhs.span) - case clause: ControlFlowClause => clauseAssigningLastExpression(clause) - case expr: ControlFlowExpression => transform(expr) + case StatementList(statements) => StatementList(stmtListAssigningLastExpression(statements))(rhs.span) + case clause: ControlFlowClause => clauseAssigningLastExpression(clause) + case expr: ControlFlowStatement => transform(expr) case _ => SingleAssignment(lhs, op, rhs)(rhs.span.spanStart(s"${lhs.span.text} $op ${rhs.span.text}")) } } - // `x.y = 1` is lowered as `x.y=(1)`, i.e. as calling `y=` on `x` with argument `1` + // `x.y = 1` is approximated as `x.y = 1`, i.e. as calling `x.y =` assignment with argument `1` + // This has the benefit of avoiding unnecessary call resolution protected def astForAttributeAssignment(node: AttributeAssignment): Ast = { - val call = SimpleCall(node, List(node.rhs))(node.span) - val memberAccess = MemberAccess(node.target, ".", s"${node.attributeName}=")(node.span) - astForMemberCallWithoutBlock(call, memberAccess) + val memberAccess = MemberAccess(node.target, ".", s"@${node.attributeName}")( + node.span.spanStart(s"${node.target.text}.${node.attributeName}") + ) + + val assignmentOp = AssignmentOperatorNames(node.assignmentOperator) + + val lhsAst = astForFieldAccess(memberAccess, stripLeadingAt = true) + val rhsAst = astForExpression(node.rhs) + val call = callNode(node, code(node), assignmentOp, assignmentOp, DispatchTypes.STATIC_DISPATCH) + callAst(call, Seq(lhsAst, rhsAst)) } - protected def astForSimpleIdentifier(node: RubyNode & RubyIdentifier): Ast = { + protected def astForSimpleIdentifier(node: RubyExpression & RubyIdentifier): Ast = { val name = code(node) if (isBundledClass(name)) { - val typeFullName = prefixAsBundledType(name) + val typeFullName = prefixAsCoreType(name) Ast(typeRefNode(node, typeFullName, typeFullName)) } else { scope.lookupVariable(name) match { @@ -445,12 +643,25 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } } - protected def astForMandatoryParameter(node: RubyNode): Ast = handleVariableOccurrence(node) + protected def astForArrayPattern(node: ArrayPattern): Ast = { + val callNode_ = + callNode(node, code(node), Operators.arrayInitializer, Operators.arrayInitializer, DispatchTypes.STATIC_DISPATCH) + val childrenAst = node.children.map { + case x: MatchVariable if scope.lookupVariable(x.text).isEmpty => handleVariableOccurrence(x.toSimpleIdentifier) + case x: MatchVariable => astForExpression(x.toSimpleIdentifier) + case x => astForExpression(x) + } + + callAst(callNode_, childrenAst) + } + + protected def astForMandatoryParameter(node: RubyExpression): Ast = handleVariableOccurrence(node) protected def astForSimpleCall(node: SimpleCall): Ast = { node.target match - case targetNode: SimpleIdentifier => astForMethodCallWithoutBlock(node, targetNode) - case targetNode: MemberAccess => astForMemberCallWithoutBlock(node, targetNode) + case targetNode: SimpleIdentifier => astForMethodCallWithoutBlock(node, targetNode) + case targetNode: RubyFieldIdentifier => astForMemberCallWithoutBlock(node, targetNode.toMemberAccess) + case targetNode: MemberAccess => astForMemberCallWithoutBlock(node, targetNode) case targetNode => logger.warn(s"Unrecognized target of call: ${targetNode.text} ($relativeFileName), skipping") astForUnknown(targetNode) @@ -462,7 +673,16 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case _ => None } pathOpt.foreach(path => scope.addRequire(projectRoot.get, fileName, path, node.isRelative, node.isWildCard)) - astForSimpleCall(node.asSimpleCall) + + val callName = node.target.text + val requireCallNode = NewCall() + .name(node.target.text) + .code(code(node)) + .methodFullName(prefixAsKernelDefined(callName)) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .typeFullName(Defines.Any) + val arguments = astForExpression(node.argument) :: Nil + callAst(requireCallNode, arguments) } protected def astForIncludeCall(node: IncludeCall): Ast = { @@ -472,28 +692,31 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { astForSimpleCall(node.asSimpleCall) } - /** A yield in Ruby could either return the result of the block, or simply call the block, depending on runtime - * conditions. Thus we embed this in a conditional expression where the condition itself is some non-deterministic - * placeholder. + protected def astForRaiseCall(node: RaiseCall): Ast = { + val throwControlStruct = controlStructureNode(node, ControlStructureTypes.THROW, code(node)) + val args = node.arguments.map(astForExpression) + Ast(throwControlStruct).withChildren(args) + } + + /** A yield in Ruby calls an explicit (or implicit) proc parameter and returns its value. This can be lowered as + * block.call(), which is effectively how one invokes a proc parameter in any case. */ protected def astForYield(node: YieldExpr): Ast = { scope.useProcParam match { case Some(param) => - val call = astForExpression( - SimpleCall(SimpleIdentifier()(node.span.spanStart(param)), node.arguments)(node.span.spanStart(param)) - ) - val ret = returnAst(returnNode(node, code(node))) - val cond = astForExpression( - SimpleCall(SimpleIdentifier()(node.span.spanStart(tmpGen.fresh)), List())(node.span.spanStart("")) - ) - callAst( - callNode(node, code(node), Operators.conditional, Operators.conditional, DispatchTypes.STATIC_DISPATCH), - List(cond, call, ret) - ) + // We do not know if we necessarily have an explicit proc param here, or if we need to create a new one + if (scope.lookupVariable(param).isEmpty) { + scope.anonProcParam.map { param => + val paramNode = ProcParameter(param)(node.span.spanStart(s"&$param")) + astForParameter(paramNode, -1) + } + } + val loweredCall = + MemberCall(SimpleIdentifier()(node.span.spanStart(param)), ".", "call", node.arguments)(node.span) + astForExpression(loweredCall) case None => logger.warn(s"Yield expression outside of method scope: ${code(node)} ($relativeFileName), skipping") astForUnknown(node) - } } @@ -505,39 +728,60 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } protected def astForArrayLiteral(node: ArrayLiteral): Ast = { - if (node.isDynamic) { - logger.warn(s"Interpolated array literals are not supported yet: ${code(node)} ($relativeFileName), skipping") - astForUnknown(node) + val arrayInitCall = { + val base = SimpleIdentifier()(node.span.spanStart(Defines.Array)) + astForExpression(SimpleObjectInstantiation(base, Nil)(node.span)) + } + if (node.elements.isEmpty) { + arrayInitCall } else { + val tmp = this.tmpGen.fresh + + def tmpRubyNode(tmpNode: Option[RubyExpression] = None) = + SimpleIdentifier()(tmpNode.map(_.span).getOrElse(node.span).spanStart(tmp)) + + def tmpAst(tmpNode: Option[RubyExpression] = None) = astForSimpleIdentifier(tmpRubyNode(tmpNode)) + + val block = blockNode(node, node.text, Defines.Any) + scope.pushNewScope(BlockScope(block)) + val tmpLocal = NewLocal().name(tmp).code(tmp) + scope.addToScope(tmp, tmpLocal) + val arguments = if (node.text.startsWith("%")) { val argumentsType = - if (node.isStringArray) getBuiltInType(Defines.String) - else getBuiltInType(Defines.Symbol) + if (node.isStringArray) prefixAsCoreType(Defines.String) + else prefixAsCoreType(Defines.Symbol) node.elements.map { - case element @ StaticLiteral(_) => StaticLiteral(argumentsType)(element.span) - case element => element + case element @ StaticLiteral(_) => StaticLiteral(argumentsType)(element.span) + case element @ DynamicLiteral(_, expressions) => DynamicLiteral(argumentsType, expressions)(element.span) + case element => element } } else { node.elements } - val argumentAsts = arguments.map(astForExpression) + val argumentAsts = arguments.zipWithIndex.map { case (arg, idx) => + val indices = StaticLiteral(Defines.prefixAsCoreType(Defines.Integer))(arg.span.spanStart(idx.toString)) :: Nil + val base = tmpRubyNode(Option(arg)) + val indexAccess = IndexAccess(base, indices)(arg.span.spanStart(s"${base.text}[$idx]")) + val assignment = + SingleAssignment(indexAccess, "=", arg)(arg.span.spanStart(s"${indexAccess.text} = ${arg.text}")) + astForExpression(assignment) + } - val call = - callNode( - node, - code(node), - Operators.arrayInitializer, - Operators.arrayInitializer, - DispatchTypes.STATIC_DISPATCH - ) - callAst(call, argumentAsts) + val assignment = + callNode(node, code(node), Operators.assignment, Operators.assignment, DispatchTypes.STATIC_DISPATCH) + val tmpAssignment = callAst(assignment, tmpAst() :: arrayInitCall :: Nil) + val tmpRetAst = tmpAst(node.elements.lastOption) + + scope.popScope() + blockAst(block, tmpAssignment +: argumentAsts :+ tmpRetAst) } } - protected def astForHashLiteral(node: HashLiteral): Ast = { - val tmp = tmpGen.fresh + protected def astForHashLiteral(node: HashLike): Ast = { + val tmp = this.tmpGen.fresh - def tmpAst(tmpNode: Option[RubyNode] = None) = astForSimpleIdentifier( + def tmpAst(tmpNode: Option[RubyExpression] = None) = astForSimpleIdentifier( SimpleIdentifier()(tmpNode.map(_.span).getOrElse(node.span).spanStart(tmp)) ) @@ -548,7 +792,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val argumentAsts = node.elements.flatMap(elem => elem match - case associationNode: Association => astForAssociationHash(associationNode, tmp) + case associationNode: Association => astForAssociationHash(associationNode, tmp) + case splattingRubyNode: SplattingRubyNode => astForSplattingRubyNode(splattingRubyNode) :: Nil case node => logger.warn(s"Could not represent element: ${code(node)} ($relativeFileName), skipping") astForUnknown(node) :: Nil @@ -573,6 +818,13 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { protected def astForAssociationHash(node: Association, tmp: String): List[Ast] = { node.key match { + case mod: AccessModifier => + // Modifiers aren't allowed here, will be shadowed by a simple identifier + astForAssociationHash(node.copy(key = mod.toSimpleIdentifier)(node.span), tmp) + case iden: SimpleIdentifier => + // An identifier here will always be interpreted as a symbol + val sym = StaticLiteral(Defines.prefixAsCoreType(Defines.Symbol))(iden.span.spanStart(s":${iden.text}")) + astForAssociationHash(node.copy(key = sym)(node.span), tmp) case rangeExpr: RangeExpression => val expandedList = generateStaticLiteralsForRange(rangeExpr).map { x => astForSingleKeyValue(x, node.value, tmp) @@ -583,42 +835,35 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } else { astForSingleKeyValue(node.key, node.value, tmp) :: Nil } - case _ => astForSingleKeyValue(node.key, node.value, tmp) :: Nil } } protected def generateStaticLiteralsForRange(node: RangeExpression): List[StaticLiteral] = { (node.lowerBound, node.upperBound) match { - case (lb: StaticLiteral, ub: StaticLiteral) => - (lb.typeFullName, ub.typeFullName) match { - case (s"${GlobalTypes.`kernelPrefix`}.Integer", s"${GlobalTypes.`kernelPrefix`}.Integer") => - generateRange(lb.span.text.toInt, ub.span.text.toInt, node.rangeOperator.exclusive) - .map(x => - StaticLiteral(lb.typeFullName)(TextSpan(lb.line, lb.column, lb.lineEnd, lb.columnEnd, x.toString)) - ) - .toList - case (s"${GlobalTypes.`kernelPrefix`}.String", s"${GlobalTypes.`kernelPrefix`}.String") => - val lbVal = lb.span.text.replaceAll("['\"]", "") - val ubVal = ub.span.text.replaceAll("['\"]", "") - - // TODO: Also might need to check if one is upper case and other is lower, since in Ruby this would not - // create any range but it might with this impl of using ASCII values. - if (lbVal.length > 1 || ubVal.length > 1) { - // Not simulating the case where we have something like "ab"..."ad" - return List.empty - } - - generateRange(lbVal(0).toInt, ubVal(0).toInt, node.rangeOperator.exclusive) - .map(x => - StaticLiteral(lb.typeFullName)( - TextSpan(lb.line, lb.column, lb.lineEnd, lb.columnEnd, s"\'${x.toChar.toString}\'") - ) - ) - .toList - case _ => - List.empty + case (lb @ StaticLiteral(Defines.Integer), ub @ StaticLiteral(Defines.Integer)) => + generateRange(lb.span.text.toInt, ub.span.text.toInt, node.rangeOperator.exclusive) + .map(x => + StaticLiteral(lb.typeFullName)(TextSpan(lb.line, lb.column, lb.lineEnd, lb.columnEnd, None, x.toString)) + ) + .toList + case (lb @ StaticLiteral(Defines.String), ub @ StaticLiteral(Defines.String)) => + val lbVal = lb.span.text.replaceAll("['\"]", "") + val ubVal = ub.span.text.replaceAll("['\"]", "") + + // TODO: Also might need to check if one is upper case and other is lower, since in Ruby this would not + // create any range but it might with this impl of using ASCII values. + if (lbVal.length > 1 || ubVal.length > 1) { + // Not simulating the case where we have something like "ab"..."ad" + return List.empty } + generateRange(lbVal(0).toInt, ubVal(0).toInt, node.rangeOperator.exclusive) + .map(x => + StaticLiteral(lb.typeFullName)( + TextSpan(lb.line, lb.column, lb.lineEnd, lb.columnEnd, None, s"\'${x.toChar.toString}\'") + ) + ) + .toList case _ => List.empty } @@ -637,13 +882,22 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { callAst(call, Seq(key, value)) } - protected def astForSingleKeyValue(keyNode: RubyNode, valueNode: RubyNode, tmp: String): Ast = { + protected def astForSingleKeyValue(keyNode: RubyExpression, valueNode: RubyExpression, tmp: String): Ast = { astForExpression( SingleAssignment( IndexAccess( - SimpleIdentifier()(TextSpan(keyNode.line, keyNode.column, keyNode.lineEnd, keyNode.columnEnd, tmp)), + SimpleIdentifier()(TextSpan(keyNode.line, keyNode.column, keyNode.lineEnd, keyNode.columnEnd, None, tmp)), List(keyNode) - )(TextSpan(keyNode.line, keyNode.column, keyNode.lineEnd, keyNode.columnEnd, s"$tmp[${keyNode.span.text}]")), + )( + TextSpan( + keyNode.line, + keyNode.column, + keyNode.lineEnd, + keyNode.columnEnd, + None, + s"$tmp[${keyNode.span.text}]" + ) + ), "=", valueNode )( @@ -652,34 +906,13 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { keyNode.column, keyNode.lineEnd, keyNode.columnEnd, + None, s"$tmp[${keyNode.span.text}] = ${valueNode.span.text}" ) ) ) } - // Recursively lowers into a ternary conditional call - protected def astForIfExpression(node: IfExpression): Ast = { - def builder(node: IfExpression, conditionAst: Ast, thenAst: Ast, elseAsts: List[Ast]): Ast = { - // We want to make sure there's always an «else» clause in a ternary operator. - // The default value is a `nil` literal. - val elseAsts_ = if (elseAsts.isEmpty) { - List(astForNilBlock) - } else { - elseAsts - } - - val call = callNode(node, code(node), Operators.conditional, Operators.conditional, DispatchTypes.STATIC_DISPATCH) - callAst(call, conditionAst :: thenAst :: elseAsts_) - } - foldIfExpression(builder)(node) - } - - protected def astForUnlessExpression(node: UnlessExpression): Ast = { - val notConditionAst = UnaryExpression("!", node.condition)(node.condition.span) - astForExpression(IfExpression(notConditionAst, node.trueBranch, List(), node.falseBranch)(node.span)) - } - protected def astForRescueExpression(node: RescueExpression): Ast = { val tryAst = astForStatementList(node.body.asStatementList) val rescueAsts = node.rescueClauses @@ -696,27 +929,35 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case x: NewMethodParameterIn => Ast(x.dynamicTypeHintFullName(classes)) } .toList - astForStatementList(x.thenClause.asStatementList).withChildren(variables) + val rescueNode = controlStructureNode(x.thenClause.asStatementList, ControlStructureTypes.CATCH, "catch") + Ast(rescueNode).withChild(astForStatementList(x.thenClause.asStatementList).withChildren(variables)) } - val elseAst = node.elseClause.map { x => astForStatementList(x.thenClause.asStatementList) } - val ensureAst = node.ensureClause.map { x => astForStatementList(x.thenClause.asStatementList) } - tryCatchAstWithOrder( - NewControlStructure() - .controlStructureType(ControlStructureTypes.TRY) - .code(code(node)), - tryAst, - rescueAsts ++ elseAst.toSeq, - ensureAst - ) + val elseAst = node.elseClause.map { x => + val astForClause = controlStructureNode(x.thenClause.asStatementList, ControlStructureTypes.ELSE, "else") + Ast(astForClause).withChild(astForStatementList(x.thenClause.asStatementList)) + } + + val ensureAst = node.ensureClause.map { x => + val astForEnsureClause = + controlStructureNode(x.thenClause.asStatementList, ControlStructureTypes.FINALLY, "finally") + Ast(astForEnsureClause).withChild(astForStatementList(x.thenClause.asStatementList)) + } + + val tryNode = controlStructureNode(node.body.asStatementList, ControlStructureTypes.TRY, "try") + tryCatchAst(tryNode, tryAst, rescueAsts ++ elseAst, ensureAst) } private def astForSelfIdentifier(node: SelfIdentifier): Ast = { val thisIdentifier = identifierNode(node, Defines.Self, code(node), scope.surroundingTypeFullName.getOrElse(Defines.Any)) - Ast(thisIdentifier) + + scope + .lookupVariable(Defines.Self) + .map(selfParam => Ast(thisIdentifier).withRefEdge(thisIdentifier, selfParam)) + .getOrElse(Ast(thisIdentifier)) } - protected def astForUnknown(node: RubyNode): Ast = { + protected def astForUnknown(node: RubyExpression): Ast = { val className = node.getClass.getSimpleName val text = code(node) logger.warn(s"Could not represent expression: $text ($className) ($relativeFileName), skipping") @@ -724,16 +965,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } private def astForMemberCallWithoutBlock(node: SimpleCall, memberAccess: MemberAccess): Ast = { - val receiverAst = astForFieldAccess(memberAccess) - val methodName = memberAccess.memberName - // TODO: Type recovery should potentially resolve this - val methodFullName = typeFromCallTarget(memberAccess.target) - .map(x => s"$x:$methodName") - .getOrElse(XDefines.DynamicCallUnknownFullName) - val argumentAsts = node.arguments.map(astForMethodCallArgument) - val call = - callNode(node, code(node), methodName, XDefines.DynamicCallUnknownFullName, DispatchTypes.DYNAMIC_DISPATCH) - .possibleTypes(IndexedSeq(methodFullName)) + val receiverAst = astForFieldAccess(memberAccess) + val methodName = memberAccess.memberName + val methodFullName = XDefines.DynamicCallUnknownFullName + val argumentAsts = node.arguments.map(astForMethodCallArgument) + val call = callNode(node, code(node), methodName, methodFullName, DispatchTypes.DYNAMIC_DISPATCH) callAst(call, argumentAsts, Some(receiverAst)) } @@ -753,41 +989,60 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { ) // Check if this is a method invocation of a member imported into scope match { case Some(m) => - scope.typeForMethod(m).map(t => t.name -> s"${t.name}:${m.name}").getOrElse(defaultResult) + scope.typeForMethod(m).map(t => t.name -> s"${t.name}.${m.name}").getOrElse(defaultResult) case None => defaultResult } val argumentAst = node.arguments.map(astForMethodCallArgument) val (dispatchType, methodFullName) = - if receiverType.startsWith(GlobalTypes.builtinPrefix) then (DispatchTypes.STATIC_DISPATCH, methodFullNameHint) + if receiverType.startsWith(GlobalTypes.corePrefix) then (DispatchTypes.STATIC_DISPATCH, methodFullNameHint) else (DispatchTypes.DYNAMIC_DISPATCH, XDefines.DynamicCallUnknownFullName) val call = callNode(node, code(node), methodName, methodFullName, dispatchType) if methodFullName != methodFullNameHint then call.possibleTypes(IndexedSeq(methodFullNameHint)) - val receiverAst = astForExpression( - MemberAccess(SelfIdentifier()(node.span.spanStart(Defines.Self)), ".", call.name)(node.span) + val receiverAst = astForFieldAccess( + MemberAccess(SelfIdentifier()(node.span.spanStart(Defines.Self)), ".", call.name)(node.span), + stripLeadingAt = true ) - val baseAst = Ast(identifierNode(node, Defines.Self, Defines.Self, receiverType)) + val selfIdentifier = identifierNode(node, Defines.Self, Defines.Self, receiverType) + val baseAst = scope + .lookupVariable(Defines.Self) + .map(selfParam => Ast(selfIdentifier).withRefEdge(selfIdentifier, selfParam)) + .getOrElse(Ast(selfIdentifier)) callAst(call, argumentAst, Option(baseAst), Option(receiverAst)) } private def astForProcOrLambdaExpr(node: ProcOrLambdaExpr): Ast = { - val Seq(_, methodRef) = astForDoBlock(node.block): @unchecked - methodRef + val Seq(typeRef, _) = astForDoBlock(node.block): @unchecked + typeRef + } + + private def astForSingletonObjectMethodDeclaration(node: SingletonObjectMethodDeclaration): Ast = { + val methodAstsWithRefs = astForMethodDeclaration(node, isSingletonObjectMethod = true) + + // Set span contents + methodAstsWithRefs.flatMap(_.nodes).foreach { + case m: NewMethodRef => DummyNode(m.copy)(node.body.span.spanStart(m.code)) + case _ => + } + + val Seq(typeRef, _) = methodAstsWithRefs + + typeRef } - private def astForMethodCallArgument(node: RubyNode): Ast = { + private def astForMethodCallArgument(node: RubyExpression): Ast = { node match // Associations in method calls are keyword arguments case assoc: Association => astForKeywordArgument(assoc) case block: RubyBlock => - val Seq(methodDecl, typeDecl, _, methodRef) = astForDoBlock(block) + val Seq(methodDecl, typeDecl, typeRef, _) = astForDoBlock(block) Ast.storeInDiffGraph(methodDecl, diffGraph) Ast.storeInDiffGraph(typeDecl, diffGraph) - methodRef + typeRef case selfMethod: SingletonMethodDeclaration => // Last element is the method declaration, the prefix methods would be `foo = def foo (...)` pointers in other // contexts, but this would be empty as a method call argument @@ -799,59 +1054,42 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } Ast.storeInDiffGraph(methodDeclAst, diffGraph) scope.surroundingScopeFullName - .map(s => Ast(methodRefNode(node, selfMethod.span.text, s"$s:${selfMethod.methodName}", Defines.Any))) + .map(s => Ast(methodRefNode(node, selfMethod.span.text, s"$s.${selfMethod.methodName}", Defines.Any))) .getOrElse(Ast()) case _ => astForExpression(node) } private def astForKeywordArgument(assoc: Association): Ast = { + + def setArgumentName(argumentAst: Ast, name: String): Ast = { + argumentAst.root.collectFirst { case x: ExpressionNew => + x.argumentName_=(Option(name)) + x.argumentIndex_=(-1) + } + argumentAst + } + val value = astForExpression(assoc.value) - assoc.key match - case keyIdentifier: SimpleIdentifier => - value.root.collectFirst { case x: ExpressionNew => - x.argumentName_=(Option(keyIdentifier.text)) - x.argumentIndex_=(-1) - } - value - case _: StaticLiteral => astForExpression(assoc) + assoc.key match { + case keyIdentifier: SimpleIdentifier => setArgumentName(value, keyIdentifier.text) + case symbol @ StaticLiteral(Defines.Symbol) => setArgumentName(value, symbol.text.stripPrefix(":")) + case _: (LiteralExpr | RubyCall | ProcOrLambdaExpr | MemberAccess | IndexAccess) => astForExpression(assoc) case x => logger.warn(s"Not explicitly handled argument association key of type ${x.getClass.getSimpleName}") astForExpression(assoc) - } - - protected def astForFieldAccess(node: MemberAccess): Ast = { - val fieldIdentifierAst = Ast(fieldIdentifierNode(node, node.memberName, node.memberName)) - val targetAst = astForExpression(node.target) - val code = s"${node.target.text}${node.op}${node.memberName}" - val memberType = typeFromCallTarget(node.target) - .flatMap(scope.tryResolveTypeReference) - .map(_.fields) - .getOrElse(List.empty) - .collectFirst { - case x if x.name == node.memberName => - scope.tryResolveTypeReference(x.typeName).map(_.name).getOrElse(Defines.Any) - } - .orElse(Option(Defines.Any)) - val fieldAccess = callNode( - node, - code, - Operators.fieldAccess, - Operators.fieldAccess, - DispatchTypes.STATIC_DISPATCH, - signature = None, - typeFullName = Option(Defines.Any) - ).possibleTypes(IndexedSeq(memberType.get)) - callAst(fieldAccess, Seq(targetAst, fieldIdentifierAst)) + } } protected def astForSplattingRubyNode(node: SplattingRubyNode): Ast = { val splattingCall = callNode(node, code(node), RubyOperators.splat, RubyOperators.splat, DispatchTypes.STATIC_DISPATCH) - val argumentAst = astsForStatement(node.name) + val argumentAst = astsForStatement(node.target) callAst(splattingCall, argumentAst) } - private def getBinaryOperatorName(op: String): Option[String] = BinaryOperatorNames.get(op) - private def getUnaryOperatorName(op: String): Option[String] = UnaryOperatorNames.get(op) + private def getBinaryOperatorName(op: String): Option[String] = BinaryOperatorNames.get(op) + + private def getUnaryOperatorName(op: String): Option[String] = UnaryOperatorNames.get(op) + private def getAssignmentOperatorName(op: String): Option[String] = AssignmentOperatorNames.get(op) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala index 4bdc09643f52..d28a5a039ae9 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala @@ -26,7 +26,10 @@ import scala.collection.mutable trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - val procParamGen = FreshNameGenerator(i => Left(s"")) + /** As expressions may be discarded, we cannot store closure ASTs in the diffgraph at the point of creation. So we + * assume every reference to this map means that the closure AST was successfully propagated. + */ + protected val closureToRefs = mutable.Map.empty[RubyExpression, Seq[NewNode]] /** Creates method declaration related structures. * @param node @@ -36,12 +39,34 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th * @return * a method declaration with additional refs and types if specified. */ - protected def astForMethodDeclaration(node: MethodDeclaration, isClosure: Boolean = false): Seq[Ast] = { + protected def astForMethodDeclaration( + node: RubyExpression & ProcedureDeclaration, + isClosure: Boolean = false, + isSingletonObjectMethod: Boolean = false, + useSurroundingTypeFullName: Boolean = false + ): Seq[Ast] = { val isInTypeDecl = scope.surroundingAstLabel.contains(NodeTypes.TYPE_DECL) val isConstructor = (node.methodName == Defines.Initialize) && isInTypeDecl val methodName = node.methodName - // TODO: body could be a try - val fullName = computeMethodFullName(methodName) + + val fullName = + node match { + case x: SingletonObjectMethodDeclaration => + computeFullName( + s"class<<${x.baseClass.span.text}.$methodName", + useSurroundingTypeFullName = useSurroundingTypeFullName + ) + case _ => computeFullName(methodName, useSurroundingTypeFullName = useSurroundingTypeFullName) + } + + val astParentType = + if useSurroundingTypeFullName || shouldUseSurroundingTypeFullName then Some(NodeTypes.TYPE_DECL) + else scope.surroundingAstLabel + + val astParentFullName = + if useSurroundingTypeFullName || shouldUseSurroundingTypeFullName then scope.surroundingTypeFullName + else scope.surroundingScopeFullName + val method = methodNode( node = node, name = methodName, @@ -49,33 +74,38 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th code = code(node), signature = None, fileName = relativeFileName, - astParentType = scope.surroundingAstLabel, - astParentFullName = scope.surroundingScopeFullName + astParentType = astParentType, + astParentFullName = astParentFullName ) val isSurroundedByProgramScope = scope.isSurroundedByProgramScope - if (isConstructor) scope.pushNewScope(ConstructorScope(fullName)) - else scope.pushNewScope(MethodScope(fullName, procParamGen.fresh)) - - val thisParameterAst = Ast( - newThisParameterNode( - code = Defines.Self, - typeFullName = scope.surroundingTypeFullName.getOrElse(Defines.Any), - line = method.lineNumber, - column = method.columnNumber - ) + if (isConstructor) scope.pushNewScope(ConstructorScope(fullName, this.procParamGen.fresh)) + else scope.pushNewScope(MethodScope(fullName, this.procParamGen.fresh)) + + val thisParameterNode = newThisParameterNode( + name = Defines.Self, + code = Defines.Self, + typeFullName = scope.surroundingTypeFullName.getOrElse(Defines.Any), + line = method.lineNumber, + column = method.columnNumber ) + val thisParameterAst = Ast(thisParameterNode) + scope.addToScope(Defines.Self, thisParameterNode) val parameterAsts = thisParameterAst :: astForParameters(node.parameters) val optionalStatementList = statementListForOptionalParams(node.parameters) val methodReturn = methodReturnNode(node, Defines.Any) - val refs = - List(typeRefNode(node, methodName, fullName), methodRefNode(node, methodName, fullName, fullName)).map(Ast.apply) + val refs = { + val typeRef = + if isClosure then typeRefNode(node, s"$methodName&Proc", s"$fullName&Proc") + else typeRefNode(node, methodName, fullName) + List(typeRef, methodRefNode(node, methodName, fullName, fullName)).map(Ast.apply) + } // Consider which variables are captured from the outer scope - val stmtBlockAst = if (isClosure) { + val stmtBlockAst = if (isClosure || isSingletonObjectMethod) { val baseStmtBlockAst = astForMethodBody(node.body, optionalStatementList) transformAsClosureBody(refs, baseStmtBlockAst) } else { @@ -90,43 +120,57 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th } // For yield statements where there isn't an explicit proc parameter - val anonProcParam = scope.anonProcParam.map { param => - val paramNode = ProcParameter(param)(node.span.spanStart(s"&$param")) + val anonProcParam = scope.procParamName.map { p => val nextIndex = - parameterAsts.lastOption.flatMap(_.root).map { case m: NewMethodParameterIn => m.index + 1 }.getOrElse(0) - astForParameter(paramNode, nextIndex) + parameterAsts.flatMap(_.root).lastOption.map { case m: NewMethodParameterIn => m.index + 1 }.getOrElse(0) + + Ast(p.index(nextIndex)) } scope.popScope() val methodTypeDeclAst = { val typeDeclNode_ = typeDeclNode(node, methodName, fullName, relativeFileName, code(node)) - scope.surroundingAstLabel.foreach(typeDeclNode_.astParentType(_)) - scope.surroundingScopeFullName.foreach(typeDeclNode_.astParentFullName(_)) + astParentType.foreach(typeDeclNode_.astParentType(_)) + astParentFullName.foreach(typeDeclNode_.astParentFullName(_)) createMethodTypeBindings(method, typeDeclNode_) - if isClosure then - Ast(typeDeclNode_) - .withChild(Ast(newModifierNode(ModifierTypes.LAMBDA))) - .withChild( - // This member refers back to itself, as itself is the type decl bound to the respective method - Ast(NewMember().name("call").code("call").dynamicTypeHintFullName(Seq(fullName)).typeFullName(Defines.Any)) - ) + if isClosure then Ast(typeDeclNode_).withChild(Ast(newModifierNode(ModifierTypes.LAMBDA))) else Ast(typeDeclNode_) } - val modifiers = mutable.Buffer(ModifierTypes.VIRTUAL) + // Due to lambdas being invoked by `call()`, this additional type ref holding that member is created. + val lambdaTypeDeclAst = if isClosure then { + val typeDeclNode_ = typeDeclNode(node, s"$methodName&Proc", s"$fullName&Proc", relativeFileName, code(node)) + astParentType.foreach(typeDeclNode_.astParentType(_)) + astParentFullName.foreach(typeDeclNode_.astParentFullName(_)) + Ast(typeDeclNode_) + .withChild( + // This member refers back to itself, as itself is the type decl bound to the respective method + Ast(NewMember().name("call").code("call").dynamicTypeHintFullName(Seq(fullName)).typeFullName(Defines.Any)) + ) + } else Ast() + + val accessModifier = + // Initialize is guaranteed `private` by the Ruby interpreter (we include our method here) + if (methodName == Defines.Initialize || methodName == Defines.TypeDeclBody) ModifierTypes.PRIVATE + //
functions are private functions on the Object class + else if (isSurroundedByProgramScope) ModifierTypes.PRIVATE + // Else, use whatever modifier has been user-defined (or is default for current scope) + else currentAccessModifier + val modifiers = mutable.Buffer(ModifierTypes.VIRTUAL, accessModifier) if (isClosure) modifiers.addOne(ModifierTypes.LAMBDA) if (isConstructor) modifiers.addOne(ModifierTypes.CONSTRUCTOR) val prefixMemberAst = - if isClosure || isSurroundedByProgramScope then Ast() // program scope members are set elsewhere + if isClosure || isSingletonObjectMethod || isSurroundedByProgramScope then + Ast() // program scope members are set elsewhere else { // Singleton constructors that initialize @@ fields should have their members linked under the singleton class val methodMember = scope.surroundingTypeFullName match { case Some(astParentTfn) => memberForMethod(method, Option(NodeTypes.TYPE_DECL), Option(astParentTfn)) case None => memberForMethod(method, scope.surroundingAstLabel, scope.surroundingScopeFullName) } - Ast(memberForMethod(method, scope.surroundingAstLabel, scope.surroundingScopeFullName)) + Ast(memberForMethod(method, Option(NodeTypes.TYPE_DECL), astParentFullName)) } // For closures, we also want the method/type refs for upstream use val methodAst_ = { @@ -141,19 +185,48 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th } // Each of these ASTs are linked via AstLinker as per the astParent* properties - (prefixMemberAst :: methodAst_ :: methodTypeDeclAst :: Nil).foreach(Ast.storeInDiffGraph(_, diffGraph)) + (prefixMemberAst :: methodAst_ :: methodTypeDeclAst :: lambdaTypeDeclAst :: Nil) + .foreach(Ast.storeInDiffGraph(_, diffGraph)) // In the case of a closure, we expect this method to return a method ref, otherwise, we bind a pointer to a // method ref, e.g. self.foo = def foo(...) - if isClosure then refs else createMethodRefPointer(method) :: Nil + if isClosure || isSingletonObjectMethod then refs else createMethodRefPointer(method) :: Nil + } + + protected def astForMethodAccessModifier(node: MethodAccessModifier): Seq[Ast] = { + val originalAccessModifier = currentAccessModifier + popAccessModifier() + + node match { + case _: PrivateMethodModifier => + pushAccessModifier(ModifierTypes.PRIVATE) + case _: PublicMethodModifier => + pushAccessModifier(ModifierTypes.PUBLIC) + } + + val methodAst = node.method match { + case m: ProcedureDeclaration => astsForStatement(m) + case x => + // Not sure how we should represent dynamically setting access modifiers based on method refs + logger.debug(s"Unhandled method reference from AST type ${x.getClass}") + Nil + } + + popAccessModifier() + pushAccessModifier(originalAccessModifier) + + methodAst } private def transformAsClosureBody(refs: List[Ast], baseStmtBlockAst: Ast) = { // Determine which locals are captured val capturedLocalNodes = baseStmtBlockAst.nodes - .collect { case x: NewIdentifier => x } + .collect { case x: NewIdentifier if x.name != Defines.Self => x } // Self identifiers are handled separately .distinctBy(_.name) - .flatMap(i => scope.lookupVariable(i.name)) + .map(i => scope.lookupVariableInOuterScope(i.name)) + .filter(_.nonEmpty) + .flatten .toSet + val capturedIdentifiers = baseStmtBlockAst.nodes.collect { case i: NewIdentifier if capturedLocalNodes.map(_.name).contains(i.name) => i } @@ -163,30 +236,42 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th case _ => false }) - val methodRefOption = refs.flatMap(_.nodes).collectFirst { case x: NewMethodRef => x } + val typeRefOption = refs.flatMap(_.nodes).collectFirst { case x: NewTypeRef => x } + val astChildren = mutable.Buffer.empty[NewNode] + val refEdges = mutable.Buffer.empty[(NewNode, NewNode)] + val captureEdges = mutable.Buffer.empty[(NewNode, NewNode)] capturedLocalNodes .collect { case local: NewLocal => - val closureBindingId = scope.surroundingScopeFullName.map(x => s"$x:${local.name}") + val closureBindingId = scope.variableScopeFullName(local.name).map(x => s"$x.${local.name}") (local, local.name, local.code, closureBindingId) case param: NewMethodParameterIn => - val closureBindingId = scope.surroundingScopeFullName.map(x => s"$x:${param.name}") + val closureBindingId = scope.variableScopeFullName(param.name).map(x => s"$x.${param.name}") (param, param.name, param.code, closureBindingId) } - .collect { case (decl, name, code, Some(closureBindingId)) => - val local = newLocalNode(name, code, Option(closureBindingId)) - val closureBinding = newClosureBindingNode(closureBindingId, name, EvaluationStrategies.BY_REFERENCE) + .collect { case (capturedLocal, name, code, Some(closureBindingId)) => + val capturingLocal = + newLocalNode(name = name, typeFullName = Defines.Any, closureBindingId = Option(closureBindingId)) + + val closureBinding = newClosureBindingNode( + closureBindingId = closureBindingId, + originalName = name, + evaluationStrategy = EvaluationStrategies.BY_REFERENCE + ) // Create new local node for lambda, with corresponding REF edges to identifiers and closure binding - capturedBlockAst.withChild(Ast(local)) - capturedIdentifiers.filter(_.name == name).foreach(i => capturedBlockAst.withRefEdge(i, local)) - diffGraph.addEdge(closureBinding, decl, EdgeTypes.REF) + val _refEdges = + capturedIdentifiers.filter(_.name == name).map(i => i -> capturingLocal) :+ (closureBinding, capturedLocal) - methodRefOption.foreach(methodRef => diffGraph.addEdge(methodRef, closureBinding, EdgeTypes.CAPTURE)) + astChildren.addOne(capturingLocal) + refEdges.addAll(_refEdges.toList) + captureEdges.addAll(typeRefOption.map(typeRef => typeRef -> closureBinding).toList) } - capturedBlockAst + val astWithAstChildren = astChildren.foldLeft(capturedBlockAst) { case (ast, child) => ast.withChild(Ast(child)) } + val astWithRefEdges = refEdges.foldLeft(astWithAstChildren) { case (ast, (src, dst)) => ast.withRefEdge(src, dst) } + captureEdges.foldLeft(astWithRefEdges) { case (ast, (src, dst)) => ast.withCaptureEdge(src, dst) } } /** Creates the bindings between the method and its types. This is useful for resolving function pointers and imports. @@ -198,7 +283,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th } // TODO: remaining cases - protected def astForParameter(node: RubyNode, index: Int): Ast = { + protected def astForParameter(node: RubyExpression, index: Int): Ast = { node match { case node: (MandatoryParameter | OptionalParameter) => val parameterIn = parameterInNode( @@ -222,23 +307,35 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th evaluationStrategy = EvaluationStrategies.BY_REFERENCE, typeFullName = None ) - scope.addToScope(node.name, parameterIn) - scope.setProcParam(node.name) - Ast(parameterIn) + scope.setProcParam(node.name, parameterIn) + Ast() // The proc parameter is retrieved later under method AST creation case node: CollectionParameter => val typeFullName = node match { - case ArrayParameter(_) => prefixAsKernelDefined("Array") - case HashParameter(_) => prefixAsKernelDefined("Hash") + case ArrayParameter(_) => prefixAsCoreType("Array") + case HashParameter(_) => prefixAsCoreType("Hash") } + val name = node.name.stripPrefix("*") val parameterIn = parameterInNode( node = node, - name = node.name, + name = name, code = code(node), index = index, isVariadic = true, evaluationStrategy = EvaluationStrategies.BY_REFERENCE, typeFullName = Option(typeFullName) ) + scope.addToScope(name, parameterIn) + Ast(parameterIn) + case node: GroupedParameter => + val parameterIn = parameterInNode( + node = node.tmpParam, + name = node.name, + code = code(node.tmpParam), + index = index, + isVariadic = false, + evaluationStrategy = EvaluationStrategies.BY_REFERENCE, + typeFullName = None + ) scope.addToScope(node.name, parameterIn) Ast(parameterIn) case node => @@ -249,11 +346,11 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th } } - private def generateTextSpan(node: RubyNode, text: String): TextSpan = { - TextSpan(node.span.line, node.span.column, node.span.lineEnd, node.span.columnEnd, text) + private def generateTextSpan(node: RubyExpression, text: String): TextSpan = { + TextSpan(node.span.line, node.span.column, node.span.lineEnd, node.span.columnEnd, node.span.offset, text) } - protected def statementForOptionalParam(node: OptionalParameter): RubyNode = { + protected def statementForOptionalParam(node: OptionalParameter): RubyExpression = { val defaultExprNode = node.defaultExpression IfExpression( @@ -301,7 +398,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th protected def astForSingletonMethodDeclaration(node: SingletonMethodDeclaration): Seq[Ast] = { node.target match { case targetNode: SingletonMethodIdentifier => - val fullName = computeMethodFullName(node.methodName) + val fullName = computeFullName(node.methodName) val (astParentType, astParentFullName, thisParamCode, addEdge) = targetNode match { case _: SelfIdentifier => @@ -321,7 +418,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th } } - scope.pushNewScope(MethodScope(fullName, procParamGen.fresh)) + scope.pushNewScope(MethodScope(fullName, this.procParamGen.fresh)) val method = methodNode( node = node, name = node.methodName, @@ -332,40 +429,46 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th ) val methodTypeDecl_ = typeDeclNode(node, node.methodName, fullName, relativeFileName, code(node)) val methodTypeDeclAst = Ast(methodTypeDecl_) - astParentType.orElse(scope.surroundingAstLabel).foreach { t => - methodTypeDecl_.astParentType(t) - method.astParentType(t) - } - astParentFullName.orElse(scope.surroundingScopeFullName).foreach { fn => - methodTypeDecl_.astParentFullName(fn) - method.astParentFullName(fn) - } createMethodTypeBindings(method, methodTypeDecl_) - val thisParameterAst = Ast( - newThisParameterNode( - name = Defines.Self, - code = thisParamCode, - typeFullName = astParentFullName.getOrElse(Defines.Any), - line = method.lineNumber, - column = method.columnNumber - ) + val thisNodeTypeFullName = astParentFullName match { + case Some(fn) => s"$fn" + case None => Defines.Any + } + + val thisNode = newThisParameterNode( + name = Defines.Self, + code = thisParamCode, + typeFullName = thisNodeTypeFullName, + line = method.lineNumber, + column = method.columnNumber ) + val thisParameterAst = Ast(thisNode) + scope.addToScope(Defines.Self, thisNode) - val parameterAsts = astForParameters(node.parameters) + val parameterAsts = thisParameterAst :: astForParameters(node.parameters) val optionalStatementList = statementListForOptionalParams(node.parameters) val stmtBlockAst = astForMethodBody(node.body, optionalStatementList) - val anonProcParam = scope.anonProcParam.map { param => - val paramNode = ProcParameter(param)(node.span.spanStart(s"&$param")) + val anonProcParam = scope.procParamName.map { p => val nextIndex = - parameterAsts.lastOption.flatMap(_.root).map { case m: NewMethodParameterIn => m.index + 1 }.getOrElse(1) - astForParameter(paramNode, nextIndex) + parameterAsts.flatMap(_.root).lastOption.map { case m: NewMethodParameterIn => m.index + 1 }.getOrElse(0) + + Ast(p.index(nextIndex)) } scope.popScope() + astParentType.orElse(scope.surroundingAstLabel).foreach { t => + methodTypeDecl_.astParentType(t) + method.astParentType(t) + } + astParentFullName.orElse(scope.surroundingScopeFullName).foreach { fn => + methodTypeDecl_.astParentFullName(fn) + method.astParentFullName(fn) + } + // The member for these types refers to the singleton class val member = memberForMethod(method, Option(NodeTypes.TYPE_DECL), astParentFullName.map(x => s"$x")) diffGraph.addNode(member) @@ -373,10 +476,10 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th val _methodAst = methodAst( method, - (thisParameterAst +: parameterAsts) ++ anonProcParam, + parameterAsts ++ anonProcParam, stmtBlockAst, methodReturnNode(node, Defines.Any), - newModifierNode(ModifierTypes.VIRTUAL) :: Nil + newModifierNode(ModifierTypes.VIRTUAL) :: newModifierNode(currentAccessModifier) :: Nil ) _methodAst :: methodTypeDeclAst :: Nil foreach (Ast.storeInDiffGraph(_, diffGraph)) @@ -417,7 +520,11 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th .methodFullName(Operators.fieldAccess) .dispatchType(DispatchTypes.STATIC_DISPATCH) .typeFullName(Defines.Any) - callAst(fieldAccess, Seq(Ast(self), Ast(fi))) + val selfAst = scope + .lookupVariable(Defines.Self) + .map(selfParam => Ast(self).withRefEdge(self, selfParam)) + .getOrElse(Ast(self)) + callAst(fieldAccess, Seq(selfAst, Ast(fi))) } astForAssignment(methodRefIdent, methodRefNode, method.lineNumber, method.columnNumber) @@ -426,24 +533,24 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th } } - private def astForParameters(parameters: List[RubyNode]): List[Ast] = { + private def astForParameters(parameters: List[RubyExpression]): List[Ast] = { parameters.zipWithIndex.map { case (parameterNode, index) => astForParameter(parameterNode, index + 1) } } - private def statementListForOptionalParams(params: List[RubyNode]): StatementList = { + private def statementListForOptionalParams(params: List[RubyExpression]): StatementList = { StatementList( params .collect { case x: OptionalParameter => x } .map(statementForOptionalParam) - )(TextSpan(None, None, None, None, "")) + )(TextSpan(None, None, None, None, None, "")) } private def astForMethodBody( - body: RubyNode, + body: RubyExpression, optionalStatementList: StatementList, returnLastExpression: Boolean = true ): Ast = { @@ -472,7 +579,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th } } - private def astForConstructorMethodBody(body: RubyNode, optionalStatementList: StatementList): Ast = { + private def astForConstructorMethodBody(body: RubyExpression, optionalStatementList: StatementList): Ast = { if (this.parseLevel == AstParseLevel.SIGNATURES) { Ast() } else { @@ -490,4 +597,27 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th } } + private val accessModifierStack: mutable.Stack[String] = mutable.Stack.empty + + protected def currentAccessModifier: String = { + accessModifierStack.headOption.getOrElse(ModifierTypes.PUBLIC) + } + + protected def pushAccessModifier(name: String): Unit = { + accessModifierStack.push(name) + } + + protected def popAccessModifier(): Unit = { + if (accessModifierStack.nonEmpty) accessModifierStack.pop() + } + + private def shouldUseSurroundingTypeFullName: Boolean = { + val inBodyMethodScope = + scope.surroundingScopeFullName.exists(x => x.split("[.]").takeRight(1).contains(Defines.TypeDeclBody)) + + scope.surroundingAstLabel match { + case Some(NodeTypes.METHOD) => inBodyMethodScope + case _ => false + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index 68504e6cd561..f8fa937dc4ec 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -2,60 +2,100 @@ package io.joern.rubysrc2cpg.astcreation import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.* import io.joern.rubysrc2cpg.datastructures.BlockScope +import io.joern.rubysrc2cpg.parser.RubyJsonHelpers import io.joern.rubysrc2cpg.passes.Defines -import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType +import io.joern.rubysrc2cpg.passes.Defines.prefixAsKernelDefined +import io.joern.x2cpg.datastructures.MethodLike import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.ControlStructureTypes -import io.shiftleft.codepropertygraph.generated.nodes.{NewControlStructure, NewMethod, NewMethodRef, NewTypeDecl} +import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewControlStructure} +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, ModifierTypes, NodeTypes} trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - protected def astsForStatement(node: RubyNode): Seq[Ast] = node match - case node: WhileExpression => astForWhileStatement(node) :: Nil - case node: DoWhileExpression => astForDoWhileStatement(node) :: Nil - case node: UntilExpression => astForUntilStatement(node) :: Nil - case node: IfExpression => astForIfStatement(node) :: Nil - case node: UnlessExpression => astForUnlessStatement(node) :: Nil - case node: ForExpression => astForForExpression(node) :: Nil - case node: CaseExpression => astsForCaseExpression(node) - case node: StatementList => astForStatementList(node) :: Nil - case node: SimpleCallWithBlock => astForCallWithBlock(node) :: Nil - case node: MemberCallWithBlock => astForCallWithBlock(node) :: Nil - case node: ReturnExpression => astForReturnStatement(node) :: Nil - case node: AnonymousTypeDeclaration => astForAnonymousTypeDeclaration(node) :: Nil - case node: TypeDeclaration => astForClassDeclaration(node) - case node: FieldsDeclaration => astsForFieldDeclarations(node) - case node: MethodDeclaration => astForMethodDeclaration(node) - case node: SingletonMethodDeclaration => astForSingletonMethodDeclaration(node) - case node: MultipleAssignment => node.assignments.map(astForExpression) - case node: BreakStatement => astForBreakStatement(node) :: Nil - case _ => astForExpression(node) :: Nil - - private def astForWhileStatement(node: WhileExpression): Ast = { - val conditionAst = astForExpression(node.condition) - val bodyAsts = astsForStatement(node.body) - whileAst(Some(conditionAst), bodyAsts, Option(code(node)), line(node), column(node)) + protected def astsForStatement(node: RubyExpression): Seq[Ast] = { + baseAstCache.clear() // A safe approximation on where to reset the cache + node match { + case node: IfExpression => astForIfStatement(node) + case node: OperatorAssignment => astForOperatorAssignment(node) + case node: CaseExpression => astsForCaseExpression(node) + case node: StatementList => astForStatementList(node) :: Nil + case node: ReturnExpression => astForReturnExpression(node) :: Nil + case node: AnonymousTypeDeclaration => astForAnonymousTypeDeclaration(node) :: Nil + case node: TypeDeclaration => astForClassDeclaration(node) + case node: FieldsDeclaration => astsForFieldDeclarations(node) + case node: AccessModifier => astForAccessModifier(node) + case node: MethodDeclaration => astForMethodDeclaration(node) + case node: MethodAccessModifier => astForMethodAccessModifier(node) + case node: SingletonMethodDeclaration => astForSingletonMethodDeclaration(node) + case node: MultipleAssignment => node.assignments.map(astForExpression) + case node: BreakExpression => astForBreakExpression(node) :: Nil + case node: SingletonStatementList => astForSingletonStatementList(node) + case node: AliasStatement => astForAliasStatement(node) + case _ => astForExpression(node) :: Nil + } + } + + private def astForIfStatement(node: IfExpression): Seq[Ast] = { + def builder(node: IfExpression, conditionAst: Ast, thenAst: Ast, elseAsts: List[Ast]): Ast = { + val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code(node)) + controlStructureAst(ifNode, Some(conditionAst), thenAst :: elseAsts) + } + + // TODO: Remove or modify the builder pattern when we are no longer using ANTLR + node.elseClause match { + case Some(elseClause) => + elseClause match { + case _: IfExpression => astForJsonIfStatement(node) + case _ => foldIfExpression(builder)(node) :: Nil + } + case None => + foldIfExpression(builder)(node) :: Nil + } + } + + private def astForOperatorAssignment(node: OperatorAssignment): Seq[Ast] = { + val loweredAssignment = lowerAssignmentOperator(node.lhs, node.rhs, node.op, node.span) + astsForStatement(loweredAssignment) } - private def astForDoWhileStatement(node: DoWhileExpression): Ast = { + private def astForJsonIfStatement(node: IfExpression): Seq[Ast] = { val conditionAst = astForExpression(node.condition) - val bodyAsts = astsForStatement(node.body) - doWhileAst(Some(conditionAst), bodyAsts, Option(code(node)), line(node), column(node)) + val thenAst = astForThenClause(node.thenClause) + val elseAsts = node.elseClause + .map { + case x: IfExpression => + val wrappedBlock = blockNode(x) + Ast(wrappedBlock).withChildren(astForJsonIfStatement(x)) :: Nil + case x => + astForElseClause(x) :: Nil + } + .getOrElse(Ast() :: Nil) + + val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code(node)) + controlStructureAst(ifNode, Some(conditionAst), thenAst +: elseAsts) :: Nil } - // `until T do B` is lowered as `while !T do B` - private def astForUntilStatement(node: UntilExpression): Ast = { - val notCondition = astForExpression(UnaryExpression("!", node.condition)(node.condition.span)) - val bodyAsts = astsForStatement(node.body) - whileAst(Some(notCondition), bodyAsts, Option(code(node)), line(node), column(node)) + private def astForAccessModifier(node: AccessModifier): Seq[Ast] = { + scope.surroundingAstLabel match { + case Some(x) if x == NodeTypes.METHOD => + val simpleIdent = node.toSimpleIdentifier + astForSimpleCall(SimpleCall(simpleIdent, List.empty)(simpleIdent.span)) :: Nil + case _ => + registerAccessModifier(node) + } } - private def astForIfStatement(node: IfExpression): Ast = { - def builder(node: IfExpression, conditionAst: Ast, thenAst: Ast, elseAsts: List[Ast]): Ast = { - val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code(node)) - controlStructureAst(ifNode, Some(conditionAst), thenAst :: elseAsts) + /** Registers the currently set access modifier for the current type (until it is reset later). + */ + private def registerAccessModifier(node: AccessModifier): Seq[Ast] = { + val modifier = node match { + case PrivateModifier() => ModifierTypes.PRIVATE + case ProtectedModifier() => ModifierTypes.PROTECTED + case PublicModifier() => ModifierTypes.PUBLIC } - foldIfExpression(builder)(node) + popAccessModifier() // pop off the current modifier in scope + pushAccessModifier(modifier) // push new one on + Nil } // Rewrites a nested `if T_1 then E_1 elsif T_2 then E_2 elsif ... elsif T_n then E_n else E_{n+1}` @@ -67,11 +107,11 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t builder(node, conditionAst, thenAst, elseAsts) } - private def astForThenClause(node: RubyNode): Ast = astForStatementList(node.asStatementList) + protected def astForThenClause(node: RubyExpression): Ast = astForStatementList(node.asStatementList) private def astsForElseClauses( - elsIfClauses: List[RubyNode], - elseClause: Option[RubyNode], + elsIfClauses: List[RubyExpression], + elseClause: Option[RubyExpression], astForIf: IfExpression => Ast ): List[Ast] = { elsIfClauses match @@ -88,96 +128,6 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t Nil } - private def astForElseClause(node: RubyNode): Ast = { - node match - case elseNode: ElseClause => - elseNode.thenClause match - case stmtList: StatementList => astForStatementList(stmtList) - case node => - logger.warn(s"Expecting statement list in ${code(node)} ($relativeFileName), skipping") - astForUnknown(node) - case elseNode => - logger.warn(s"Expecting else clause in ${code(elseNode)} ($relativeFileName), skipping") - astForUnknown(elseNode) - } - - // `unless T do B` is lowered as `if !T then B` - private def astForUnlessStatement(node: UnlessExpression): Ast = { - val notConditionAst = astForExpression(UnaryExpression("!", node.condition)(node.condition.span)) - val thenAst = node.trueBranch match - case stmtList: StatementList => astForStatementList(stmtList) - case _ => astForStatementList(StatementList(List(node.trueBranch))(node.trueBranch.span)) - val elseAsts = node.falseBranch.map(astForElseClause).toList - val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code(node)) - controlStructureAst(ifNode, Some(notConditionAst), thenAst :: elseAsts) - } - - private def astForForExpression(node: ForExpression): Ast = { - val forEachNode = controlStructureNode(node, ControlStructureTypes.FOR, code(node)) - val doBodyAst = astsForStatement(node.doBlock) - val iteratorNode = astForExpression(node.forVariable) - val iterableNode = astForExpression(node.iterableVariable) - Ast(forEachNode).withChild(iteratorNode).withChild(iterableNode).withChildren(doBodyAst) - } - - protected def astsForCaseExpression(node: CaseExpression): Seq[Ast] = { - def goCase(expr: Option[SimpleIdentifier]): List[RubyNode] = { - val elseThenClause: Option[RubyNode] = node.elseClause.map(_.asInstanceOf[ElseClause].thenClause) - val whenClauses = node.whenClauses.map(_.asInstanceOf[WhenClause]) - val ifElseChain = whenClauses.foldRight[Option[RubyNode]](elseThenClause) { - (whenClause: WhenClause, restClause: Option[RubyNode]) => - // We translate multiple match expressions into an or expression. - // - // A single match expression is compared using `.===` to the case target expression if it is present - // otherwise it is treated as a conditional. - // - // There may be a splat as the last match expression, - // `case y when *x then c end` or - // `case when *x then c end` - // which is translated to `x.include? y` and `x.any?` conditions respectively - - val conditions = whenClause.matchExpressions.map { mExpr => - expr.map(e => BinaryExpression(mExpr, "===", e)(mExpr.span)).getOrElse(mExpr) - } ++ (whenClause.matchSplatExpression.iterator.flatMap { - case splat @ SplattingRubyNode(exprList) => - expr - .map { e => - List(MemberCall(exprList, ".", "include?", List(e))(splat.span)) - } - .getOrElse { - List(MemberCall(exprList, ".", "any?", List())(splat.span)) - } - case e => - logger.warn(s"Unrecognised RubyNode (${e.getClass}) in case match splat expression") - List(Unknown()(e.span)) - }) - // There is always at least one match expression or a splat - // a splat will become an unknown in condition at the end - val condition = conditions.init.foldRight(conditions.last) { (cond, condAcc) => - BinaryExpression(cond, "||", condAcc)(whenClause.span) - } - val conditional = IfExpression( - condition, - whenClause.thenClause.asStatementList, - List(), - restClause.map { els => ElseClause(els.asStatementList)(els.span) } - )(node.span) - Some(conditional) - } - ifElseChain.iterator.toList - } - def generatedNode: StatementList = node.expression - .map { e => - val tmp = SimpleIdentifier(None)(e.span.spanStart(tmpGen.fresh)) - StatementList( - List(SingleAssignment(tmp, "=", e)(e.span)) ++ - goCase(Some(tmp)) - )(node.span) - } - .getOrElse(StatementList(goCase(None))(node.span)) - astsForStatement(generatedNode) - } - protected def astForStatementList(node: StatementList): Ast = { val block = blockNode(node) scope.pushNewScope(BlockScope(block)) @@ -186,56 +136,38 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t blockAst(block, statementAsts) } - /* `foo() do end` is lowered as a METHOD node shaped like so: - * ``` - * = def 0() - * - * end - * foo(, ) - * ``` - */ - protected def astForCallWithBlock[C <: RubyCall](node: RubyNode & RubyCallWithBlock[C]): Ast = { - val Seq(_, methodRefAst) = astForDoBlock(node.block): @unchecked - val methodRefDummyNode = methodRefAst.root.map(DummyNode(_)(node.span)).toList - - // Create call with argument referencing the MethodRef - val callWithLambdaArg = node.withoutBlock match { - case x: SimpleCall => astForSimpleCall(x.copy(arguments = x.arguments ++ methodRefDummyNode)(x.span)) - case x: MemberCall => astForMemberCall(x.copy(arguments = x.arguments ++ methodRefDummyNode)(x.span)) - case x => - logger.warn(s"Unhandled call-with-block type ${code(x)}, creating anonymous method structures only") - Ast() - } - - callWithLambdaArg - } - - protected def astForDoBlock(block: Block & RubyNode): Seq[Ast] = { - // Create closure structures: [MethodDecl, TypeRef, MethodRef] - val methodName = nextClosureName() - - val methodAstsWithRefs = block.body match { - case x: Block => - astForMethodDeclaration(x.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) - case _ => - astForMethodDeclaration(block.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) - } - - // Set span contents - methodAstsWithRefs.flatMap(_.nodes).foreach { - case m: NewMethodRef => DummyNode(m.copy)(block.span.spanStart(m.code)) - case _ => + protected def astForDoBlock(block: Block & RubyExpression): Seq[Ast] = { + if (closureToRefs.contains(block)) { + closureToRefs(block).map(x => Ast(x.copy)) + } else { + val methodName = nextClosureName() + // Create closure structures: [TypeRef, MethodRef] + val methodRefAsts = block.body match { + case x: Block => + astForMethodDeclaration(x.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) + case _ => + astForMethodDeclaration(block.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) + } + closureToRefs.put(block, methodRefAsts.flatMap(_.root)) + methodRefAsts } - - methodAstsWithRefs } - protected def astForReturnStatement(node: ReturnExpression): Ast = { + protected def astForReturnExpression(node: ReturnExpression): Ast = { val argumentAsts = node.expressions.map(astForExpression) val returnNode_ = returnNode(node, code(node)) returnAst(returnNode_, argumentAsts) } + protected def astForNextExpression(node: NextExpression): Ast = { + val nextNode = NewControlStructure() + .controlStructureType(ControlStructureTypes.CONTINUE) + .lineNumber(line(node)) + .columnNumber(column(node)) + .code(code(node)) + Ast(nextNode) + } + protected def astForStatementListReturningLastExpression(node: StatementList): Ast = { val block = blockNode(node) scope.pushNewScope(BlockScope(block)) @@ -250,11 +182,11 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t blockAst(block, stmtAsts) } - private def astsForImplicitReturnStatement(node: RubyNode): Seq[Ast] = { + private def astsForImplicitReturnStatement(node: RubyExpression): Seq[Ast] = { def elseReturnNil(span: TextSpan) = Option { ElseClause( StatementList( - ReturnExpression(StaticLiteral(getBuiltInType(Defines.NilClass))(span.spanStart("nil")) :: Nil)( + ReturnExpression(StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(span.spanStart("nil")) :: Nil)( span.spanStart("return nil") ) :: Nil )(span.spanStart("return nil")) @@ -262,35 +194,86 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t } node match - case expr: ControlFlowExpression => - def transform(e: RubyNode & ControlFlowExpression): RubyNode = + case expr: ControlFlowStatement => + def transform(e: RubyExpression & ControlFlowStatement): RubyExpression = transformLastRubyNodeInControlFlowExpressionBody(e, returnLastNode(_, transform), elseReturnNil) - astsForStatement(transform(expr)) + + expr match { + case x @ OperatorAssignment(lhs, op, rhs) => + val loweredAssignment = lowerAssignmentOperator(lhs, rhs, op, x.span) + astsForStatement(transform(loweredAssignment)) + case x => + astsForStatement(transform(expr)) + } case node: MemberCallWithBlock => returnAstForRubyCall(node) case node: SimpleCallWithBlock => returnAstForRubyCall(node) case _: (LiteralExpr | BinaryExpression | UnaryExpression | SimpleIdentifier | SelfIdentifier | IndexAccess | Association | YieldExpr | RubyCall | RubyFieldIdentifier | HereDocNode | Unknown) => - astForReturnStatement(ReturnExpression(List(node))(node.span)) :: Nil + astForReturnExpression(ReturnExpression(List(node))(node.span)) :: Nil case node: SingleAssignment => - astForSingleAssignment(node) :: List(astForReturnStatement(ReturnExpression(List(node.lhs))(node.span))) + astForSingleAssignment(node) :: List(astForReturnExpression(ReturnExpression(List(node.lhs))(node.span))) + case node: DefaultMultipleAssignment => + astsForStatement(node) ++ astsForImplicitReturnStatement(ArrayLiteral(node.assignments.map(_.lhs))(node.span)) + case node: GroupedParameterDesugaring => + // If the desugaring is the last expression, then we should return nil + val nilReturnSpan = node.span.spanStart("return nil") + val nilReturnLiteral = StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(nilReturnSpan) + astsForStatement(node) ++ astsForImplicitReturnStatement(nilReturnLiteral) case node: AttributeAssignment => List( astForAttributeAssignment(node), astForReturnFieldAccess(MemberAccess(node.target, node.op, node.attributeName)(node.span)) ) case node: MemberAccess => astForReturnMemberCall(node) :: Nil - case ret: ReturnExpression => astForReturnStatement(ret) :: Nil - case node: MethodDeclaration => - (astForMethodDeclaration(node) :+ astForReturnMethodDeclarationSymbolName(node)).toList - case _: BreakStatement => astsForStatement(node).toList + case ret: ReturnExpression => astForReturnExpression(ret) :: Nil + case node: (MethodDeclaration | SingletonMethodDeclaration) => + (astsForStatement(node) :+ astForReturnMethodDeclarationSymbolName(node)).toList + case stmtList: StatementList if stmtList.statements.lastOption.exists(_.isInstanceOf[ReturnExpression]) => + stmtList.statements.map(astForExpression) + case StatementList(stmts) => + val nilReturnSpan = node.span.spanStart("return nil") + val nilReturnLiteral = StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(nilReturnSpan) + stmts.map(astForExpression) ++ astsForImplicitReturnStatement(nilReturnLiteral) + case x: RangeExpression => + astForReturnRangeExpression(x) :: Nil + case node: AccessModifier => + val simpleIdent = node.toSimpleIdentifier + val simpleCall = SimpleCall(simpleIdent, List.empty)(simpleIdent.span) + astForReturnExpression(ReturnExpression(List(simpleCall))(node.span)) :: Nil + case node: MethodAccessModifier => + val simpleIdent = node.toSimpleIdentifier + + val methodIdentName = node.method match { + case x: StaticLiteral => x.span.text + case x: MethodDeclaration => x.methodName + case x => + logger.warn(s"Unknown node type for method identifier name: ${x.getClass} (${this.relativeFileName})") + x.span.text + } + + val methodIdent = SimpleIdentifier(None)(simpleIdent.span.spanStart(methodIdentName)) + + val simpleCall = SimpleCall(simpleIdent, List(methodIdent))( + simpleIdent.span.spanStart(s"${simpleIdent.span.text} ${methodIdent.span.text}") + ) + astForReturnExpression(ReturnExpression(List(simpleCall))(node.span)) :: Nil + case node: FieldsDeclaration => + val nilReturnSpan = node.span.spanStart("return nil") + val nilReturnLiteral = StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(nilReturnSpan) + astsForFieldDeclarations(node) ++ astsForImplicitReturnStatement(nilReturnLiteral) + case node: SingletonClassDeclaration => + astForAnonymousTypeDeclaration(node) + val nilReturnSpan = node.span.spanStart("return nil") + val nilReturnLiteral = StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(nilReturnSpan) + astsForImplicitReturnStatement(nilReturnLiteral) case node => logger.warn( - s"Implicit return here not supported yet: ${node.text} (${node.getClass.getSimpleName}), only generating statement" + s" not supported yet: ${node.text} (${node.getClass.getSimpleName}), only generating statement (${this.relativeFileName})" ) astsForStatement(node).toList } - private def returnAstForRubyCall[C <: RubyCall](node: RubyNode & RubyCallWithBlock[C]): Seq[Ast] = { + private def returnAstForRubyCall[C <: RubyCall](node: RubyExpression & RubyCallWithBlock[C]): Seq[Ast] = { val callAst = astForCallWithBlock(node) returnAst(returnNode(node, code(node)), List(callAst)) :: Nil } @@ -301,21 +284,21 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t // The evaluation of a MethodDeclaration returns its name in symbol form. // E.g. `def f = 0` ===> `:f` - private def astForReturnMethodDeclarationSymbolName(node: MethodDeclaration): Ast = { - val literalNode_ = literalNode(node, s":${node.methodName}", getBuiltInType(Defines.Symbol)) + private def astForReturnMethodDeclarationSymbolName(node: RubyExpression & ProcedureDeclaration): Ast = { + val literalNode_ = literalNode(node, s":${node.methodName}", prefixAsCoreType(Defines.Symbol)) val returnNode_ = returnNode(node, literalNode_.code) returnAst(returnNode_, Seq(Ast(literalNode_))) } - private def astForReturnMemberCall(node: MemberAccess): Ast = { - returnAst(returnNode(node, code(node)), List(astForMemberAccess(node))) + private def astForReturnRangeExpression(node: RangeExpression): Ast = { + returnAst(returnNode(node, code(node)), List(astForRange(node))) } - private def astForReturnMemberCall(node: MemberCall): Ast = { - returnAst(returnNode(node, code(node)), List(astForMemberCall(node))) + private def astForReturnMemberCall(node: MemberAccess): Ast = { + returnAst(returnNode(node, code(node)), List(astForMemberAccess(node))) } - protected def astForBreakStatement(node: BreakStatement): Ast = { + protected def astForBreakExpression(node: BreakExpression): Ast = { val _node = NewControlStructure() .controlStructureType(ControlStructureTypes.BREAK) .lineNumber(line(node)) @@ -324,6 +307,10 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t Ast(_node) } + protected def astForSingletonStatementList(list: SingletonStatementList): Seq[Ast] = { + list.statements.map(astForExpression) + } + /** Wraps the last RubyNode with a ReturnExpression. * @param x * the node to wrap a return around. If a StatementList is given, then the ReturnExpression will wrap around the @@ -331,17 +318,20 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t * @return * the RubyNode with an explicit expression */ - private def returnLastNode(x: RubyNode, transform: (RubyNode & ControlFlowExpression) => RubyNode): RubyNode = { - def statementListReturningLastExpression(stmts: List[RubyNode]): List[RubyNode] = stmts match { - case (head: ControlFlowClause) :: Nil => clauseReturningLastExpression(head) :: Nil - case (head: ControlFlowExpression) :: Nil => transform(head) :: Nil - case (head: ReturnExpression) :: Nil => head :: Nil - case head :: Nil => ReturnExpression(head :: Nil)(head.span) :: Nil - case Nil => List.empty - case head :: tail => head :: statementListReturningLastExpression(tail) + private def returnLastNode( + x: RubyExpression, + transform: (RubyExpression & ControlFlowStatement) => RubyExpression + ): RubyExpression = { + def statementListReturningLastExpression(stmts: List[RubyExpression]): List[RubyExpression] = stmts match { + case (head: ControlFlowClause) :: Nil => clauseReturningLastExpression(head) :: Nil + case (head: ControlFlowStatement) :: Nil => transform(head) :: Nil + case (head: ReturnExpression) :: Nil => head :: Nil + case head :: Nil => ReturnExpression(head :: Nil)(head.span) :: Nil + case Nil => List.empty + case head :: tail => head :: statementListReturningLastExpression(tail) } - def clauseReturningLastExpression(x: RubyNode & ControlFlowClause): RubyNode = x match { + def clauseReturningLastExpression(x: RubyExpression & ControlFlowClause): RubyExpression = x match { case RescueClause(exceptionClassList, assignment, thenClause) => RescueClause(exceptionClassList, assignment, returnLastNode(thenClause, transform))(x.span) case EnsureClause(thenClause) => EnsureClause(returnLastNode(thenClause, transform))(x.span) @@ -349,15 +339,15 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t case ElseClause(thenClause) => ElseClause(returnLastNode(thenClause, transform))(x.span) case WhenClause(matchExpressions, matchSplatExpression, thenClause) => WhenClause(matchExpressions, matchSplatExpression, returnLastNode(thenClause, transform))(x.span) + case InClause(pattern, body) => InClause(pattern, returnLastNode(body, transform))(x.span) } x match { - case StatementList(statements) => StatementList(statementListReturningLastExpression(statements))(x.span) - case clause: ControlFlowClause => clauseReturningLastExpression(clause) - case node: ControlFlowExpression => transform(node) - case node: BreakStatement => node - case node: ReturnExpression => node - case _ => ReturnExpression(x :: Nil)(x.span) + case StatementList(statements) => StatementList(statementListReturningLastExpression(statements))(x.span) + case clause: ControlFlowClause => clauseReturningLastExpression(clause) + case node: ControlFlowStatement => transform(node) + case node: ReturnExpression => node + case _ => ReturnExpression(x :: Nil)(x.span) } } @@ -369,10 +359,10 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t * RubyNode with transform function applied */ protected def transformLastRubyNodeInControlFlowExpressionBody( - node: RubyNode & ControlFlowExpression, - transform: RubyNode => RubyNode, + node: RubyExpression & ControlFlowStatement, + transform: RubyExpression => RubyExpression, defaultElseBranch: TextSpan => Option[ElseClause] - ): RubyNode = { + ): RubyExpression = { node match { case RescueExpression(body, rescueClauses, elseClause, ensureClause) => // Ensure never returns a value, only the main body, rescue & else clauses @@ -385,6 +375,9 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t case WhileExpression(condition, body) => WhileExpression(condition, transform(body))(node.span) case DoWhileExpression(condition, body) => DoWhileExpression(condition, transform(body))(node.span) case UntilExpression(condition, body) => UntilExpression(condition, transform(body))(node.span) + case OperatorAssignment(lhs, op, rhs) => + val loweredNode = lowerAssignmentOperator(lhs, rhs, op, node.span) + transformLastRubyNodeInControlFlowExpressionBody(loweredNode, transform, defaultElseBranch) case IfExpression(condition, thenClause, elsifClauses, elseClause) => IfExpression( condition, @@ -406,6 +399,32 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t whenClauses.map(transform), elseClause.map(transform).orElse(defaultElseBranch(node.span)) )(node.span) + case next: NextExpression => next + case break: BreakExpression => break } } + + protected def astForAliasStatement(statement: AliasStatement): Seq[Ast] = { + val aliasMethodDecl = generateAliasMethodDecl(statement) + // alias should always be lifted to the class decl + astForMethodDeclaration(aliasMethodDecl, useSurroundingTypeFullName = true) + } + + private def generateAliasMethodDecl(alias: AliasStatement): MethodDeclaration = { + val span = alias.span + val forwardingCallTarget = SimpleIdentifier(None)(span.spanStart(alias.oldName)) + val forwardedArgs = SplattingRubyNode(SimpleIdentifier()(span.spanStart("args")))(span.spanStart("*args")) + val forwardedBlock = SimpleIdentifier()(span.spanStart("&block")) + val forwardingCall = SimpleCall(forwardingCallTarget, forwardedArgs :: forwardedBlock :: Nil)( + span.spanStart(s"${alias.oldName}(*args, &block)") + ) + + val aliasMethodBody = StatementList(forwardingCall :: Nil)(forwardingCall.span) + val aliasingMethodParams = + ArrayParameter("*args")(span.spanStart("*args")) :: ProcParameter("&block")(span.spanStart("&block")) :: Nil + + MethodDeclaration(alias.newName, aliasingMethodParams, aliasMethodBody)( + alias.span.spanStart(s"def ${alias.newName}(*args, &block)") + ) + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala index 4b827f02db83..9e46d0be3d48 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala @@ -1,13 +1,14 @@ package io.joern.rubysrc2cpg.astcreation -import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.* -import io.joern.rubysrc2cpg.datastructures.{BlockScope, MethodScope, ModuleScope, TypeScope} +import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{TypeDeclaration, *} +import io.joern.rubysrc2cpg.datastructures.{BlockScope, MethodScope, ModuleScope, NamespaceScope, TypeScope} import io.joern.rubysrc2cpg.passes.Defines import io.joern.x2cpg.utils.NodeBuilders.newModifierNode import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{ DispatchTypes, + EdgeTypes, EvaluationStrategies, ModifierTypes, NodeTypes, @@ -19,7 +20,7 @@ import scala.collection.mutable trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - protected def astForClassDeclaration(node: RubyNode & TypeDeclaration): Seq[Ast] = { + protected def astForClassDeclaration(node: RubyExpression & TypeDeclaration): Seq[Ast] = { node.name match case name: SimpleIdentifier => astForSimpleNamedClassDeclaration(node, name) case name => @@ -27,46 +28,90 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: astForUnknown(node) :: Nil } - private def getBaseClassName(node: RubyNode): Option[String] = { + private def getBaseClassName(node: RubyExpression): String = { node match case simpleIdentifier: SimpleIdentifier => - val name = simpleIdentifier.text - scope.lookupVariable(name) match { - case Some(_) => Option(name) // in the case of singleton classes, we want to keep the variable name - case None => scope.tryResolveTypeReference(name).map(_.name).orElse(Option(name)) - } + simpleIdentifier.text case _: SelfIdentifier => - scope.surroundingTypeFullName + Defines.Self case qualifiedBaseClass: MemberAccess => - scope - .tryResolveTypeReference(qualifiedBaseClass.toString) - .map(_.name) - .orElse(Option(qualifiedBaseClass.toString)) + qualifiedBaseClass.text.replace("::", ".") + case qualifiedBaseClass: MemberCall => + qualifiedBaseClass.text.replace("::", ".") case x => logger.warn( - s"Base class names of type ${x.getClass} are not supported yet: ${code(node)} ($relativeFileName), skipping" + s"Base class names of type ${x.getClass} are not supported yet: ${code(node)} ($relativeFileName), returning string as-is" ) - None + x.text } private def astForSimpleNamedClassDeclaration( - node: RubyNode & TypeDeclaration, + node: RubyExpression & TypeDeclaration, nameIdentifier: SimpleIdentifier ): Seq[Ast] = { - val className = nameIdentifier.text - val inheritsFrom = node.baseClass.flatMap(getBaseClassName).toList - val classFullName = computeClassFullName(className) - val typeDecl = typeDeclNode( - node = node, - name = className, - fullName = classFullName, - filename = relativeFileName, - code = code(node), - inherits = inheritsFrom, - alias = None - ) - scope.surroundingAstLabel.foreach(typeDecl.astParentType(_)) - scope.surroundingScopeFullName.foreach(typeDecl.astParentFullName(_)) + val className = nameIdentifier.text + val inheritsFrom = node.baseClass.map(getBaseClassName).toList + pushAccessModifier(ModifierTypes.PUBLIC) + + /** Pushes new NamespaceScope onto scope stack and populates AST_PARENT_FULL_NAME and AST_PARENT_TYPE for TypeDecls + * that are declared in a namespace + * @param typeDecl + * \- TypeDecl node + * @param astParentFullName + * \- Fullname of AstParent + * @return + * typeDecl node with updated fields + */ + def populateAstParentValues(typeDecl: NewTypeDecl, astParentFullName: String): NewTypeDecl = { + val namespaceBlockFullName = s"${scope.surroundingScopeFullName.getOrElse("")}.$astParentFullName" + scope.pushNewScope(NamespaceScope(namespaceBlockFullName)) + + val namespaceBlock = + NewNamespaceBlock().name(astParentFullName).fullName(astParentFullName).filename(relativeFileName) + + diffGraph.addNode(namespaceBlock) + + fileNode.foreach(diffGraph.addEdge(_, namespaceBlock, EdgeTypes.AST)) + + typeDecl.astParentFullName(astParentFullName) + typeDecl.astParentType(NodeTypes.NAMESPACE_BLOCK) + + typeDecl.fullName(computeFullName(className)) + typeDecl + } + + val (typeDecl, classFullName, shouldPopAdditionalScope) = node match { + case x: NamespaceDeclaration if x.namespaceParts.isDefined => + val className = nameIdentifier.text + val typeDeclTemp = typeDeclNode( + node = node, + name = className, + fullName = Defines.Any, + filename = relativeFileName, + code = code(node), + inherits = inheritsFrom, + alias = None + ) + populateAstParentValues(typeDeclTemp, x.namespaceParts.get.mkString(".")) + val classFullName = typeDeclTemp.fullName + + (typeDeclTemp, classFullName, true) + case _ => + val classFullName = computeFullName(className) + val typeDeclTemp = typeDeclNode( + node = node, + name = className, + fullName = classFullName, + filename = relativeFileName, + code = code(node), + inherits = inheritsFrom, + alias = None + ) + scope.surroundingAstLabel.foreach(typeDeclTemp.astParentType(_)) + scope.surroundingScopeFullName.foreach(typeDeclTemp.astParentFullName(_)) + (typeDeclTemp, classFullName, false) + } + /* In Ruby, there are semantic differences between the ordinary class and singleton class (think "meta" class in Python). Similar to how Java allows both static and dynamic methods/fields/etc. within the same type declaration, @@ -98,17 +143,31 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: val classBody = node.body.asInstanceOf[StatementList] // for now (bodyStatement is a superset of stmtList) - def handleDefaultConstructor(bodyAsts: Seq[Ast]): Seq[Ast] = bodyAsts match { - case bodyAsts if scope.shouldGenerateDefaultConstructor && this.parseLevel == AstParseLevel.FULL_AST => + val statementsToForwardUpTheAst = mutable.ArrayBuffer.empty[Ast] + def separateStatementsFromBody(ss: List[RubyExpression]) = { + // There may be additional expression nodes introduced from nodes such as type decls, so we must + // re-distribute these back into the method + ss.flatMap { + case t: TypeDeclaration => + val (typeDeclAsts, other) = astsForStatement(t).partition(_.root.exists(_.isInstanceOf[NewTypeDecl])) + statementsToForwardUpTheAst.addAll(other) + typeDeclAsts + case n => astsForStatement(n) + } + } + + val classBodyAsts = { + val bodyAsts = separateStatementsFromBody(classBody.statements) + if (scope.shouldGenerateDefaultConstructor && this.parseLevel == AstParseLevel.FULL_AST) { val bodyStart = classBody.span.spanStart() val initBody = StatementList(List())(bodyStart) val methodDecl = astForMethodDeclaration(MethodDeclaration(Defines.Initialize, List(), initBody)(bodyStart)) methodDecl ++ bodyAsts - case bodyAsts => bodyAsts + } else { + bodyAsts + } } - val classBodyAsts = handleDefaultConstructor(classBody.statements.flatMap(astsForStatement)) - val fields = node match { case classDecl: ClassDeclaration => classDecl.fields case moduleDecl: ModuleDeclaration => moduleDecl.fields @@ -146,12 +205,23 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: .withChildren(fieldSingletonMemberNodes.map(_._2)) val bodyMemberCallAst = node.bodyMemberCall match { - case Some(bodyMemberCall) => astForMemberCall(bodyMemberCall) + case Some(bodyMemberCall) => astForTypeDeclBodyCall(bodyMemberCall, classFullName) case None => Ast() } (typeDeclAst :: singletonTypeDeclAst :: Nil).foreach(Ast.storeInDiffGraph(_, diffGraph)) - prefixAst :: bodyMemberCallAst :: Nil + + if shouldPopAdditionalScope then scope.popScope() + popAccessModifier() + prefixAst :: bodyMemberCallAst :: statementsToForwardUpTheAst.toList + } + + private def astForTypeDeclBodyCall(node: TypeDeclBodyCall, typeFullName: String): Ast = { + val callAst = astForMemberCall(node.toMemberCall, isStatic = true) + callAst.nodes.collectFirst { + case c: NewCall if c.name == Defines.TypeDeclBody => c.methodFullName(s"$typeFullName.${Defines.TypeDeclBody}") + } + callAst } private def createTypeRefPointer(typeDecl: NewTypeDecl): Ast = { @@ -179,7 +249,11 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: .methodFullName(Operators.fieldAccess) .dispatchType(DispatchTypes.STATIC_DISPATCH) .typeFullName(Defines.Any) - callAst(fieldAccess, Seq(Ast(self), Ast(fi))) + val selfAst = scope + .lookupVariable(Defines.Self) + .map(selfParam => Ast(self).withRefEdge(self, selfParam)) + .getOrElse(Ast(self)) + callAst(fieldAccess, Seq(selfAst, Ast(fi))) } astForAssignment(typeRefIdent, typeRefNode, typeDecl.lineNumber, typeDecl.columnNumber) } else { @@ -191,84 +265,58 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: node.fieldNames.flatMap(astsForSingleFieldDeclaration(node, _)) } - private def astsForSingleFieldDeclaration(node: FieldsDeclaration, nameNode: RubyNode): Seq[Ast] = { + private def astsForSingleFieldDeclaration(node: FieldsDeclaration, nameNode: RubyExpression): Seq[Ast] = { nameNode match case nameAsSymbol: StaticLiteral if nameAsSymbol.isSymbol => val fieldName = nameAsSymbol.innerText.prepended('@') val memberNode_ = memberNode(nameAsSymbol, fieldName, code(node), Defines.Any) val memberAst = Ast(memberNode_) - val getterAst = Option.when(node.hasGetter)(astForGetterMethod(node, fieldName)) - val setterAst = Option.when(node.hasSetter)(astForSetterMethod(node, fieldName)) - Seq(memberAst) ++ getterAst.toList ++ setterAst.toList + val getterAst = Option.when(node.hasGetter)(astForGetterMethod(node, fieldName)).getOrElse(Nil) + val setterAst = Option.when(node.hasSetter)(astForSetterMethod(node, fieldName)).getOrElse(Nil) + Seq(memberAst) ++ getterAst ++ setterAst + case nameAsIdent: SimpleIdentifier => + val fieldName = nameAsIdent.span.text.prepended('@') + val memberNode_ = memberNode(nameAsIdent, fieldName, code(node), Defines.Any) + val memberAst = Ast(memberNode_) + val getterAst = Option.when(node.hasGetter)(astForGetterMethod(node, fieldName)).getOrElse(Nil) + val setterAst = Option.when(node.hasSetter)(astForSetterMethod(node, fieldName)).getOrElse(Nil) + Seq(memberAst) ++ getterAst ++ setterAst case _ => - logger.warn(s"Unsupported field declaration: ${nameNode.text}, skipping") + logger.warn( + s"Unsupported field declaration: ${nameNode.text} (${nameNode.getClass}) (${this.relativeFileName}), skipping" + ) Seq() } // creates a `def () { return }` METHOD, for = @. - private def astForGetterMethod(node: FieldsDeclaration, fieldName: String): Ast = { - val name = fieldName.drop(1) - val fullName = computeMethodFullName(name) - val method = methodNode( - node = node, - name = name, - fullName = fullName, - code = s"def $name (...)", - signature = None, - fileName = relativeFileName, - astParentType = scope.surroundingAstLabel, - astParentFullName = scope.surroundingScopeFullName - ) - scope.pushNewScope(MethodScope(fullName, procParamGen.fresh)) - val block_ = blockNode(node) - scope.pushNewScope(BlockScope(block_)) - // TODO: Should it be `return this.@abc`? - val returnAst_ = { - val returnNode_ = returnNode(node, s"return $fieldName") - val fieldNameIdentifier = identifierNode(node, fieldName, fieldName, Defines.Any) - returnAst(returnNode_, Seq(Ast(fieldNameIdentifier))) - } - - val methodBody = blockAst(block_, List(returnAst_)) - scope.popScope() - scope.popScope() - methodAst(method, Seq(), methodBody, methodReturnNode(node, Defines.Any)) + private def astForGetterMethod(node: FieldsDeclaration, fieldName: String): Seq[Ast] = { + val name = fieldName.drop(1) + val code = s"def $name (...)" + val methodDecl = MethodDeclaration( + name, + Nil, + StatementList(InstanceFieldIdentifier()(node.span.spanStart(fieldName)) :: Nil)( + node.span.spanStart(s"return $fieldName") + ) + )(node.span.spanStart(code)) + astForMethodDeclaration(methodDecl, useSurroundingTypeFullName = true) } // creates a `def =(x) { = x }` METHOD, for = @ - private def astForSetterMethod(node: FieldsDeclaration, fieldName: String): Ast = { - val name = fieldName.drop(1) + "=" - val fullName = computeMethodFullName(name) - val method = methodNode( - node = node, - name = name, - fullName = fullName, - code = s"def $name (...)", - signature = None, - fileName = relativeFileName, - astParentType = scope.surroundingAstLabel, - astParentFullName = scope.surroundingScopeFullName - ) - scope.pushNewScope(MethodScope(fullName, procParamGen.fresh)) - val parameter = parameterInNode(node, "x", "x", 1, false, EvaluationStrategies.BY_REFERENCE) - val methodBody = { - val block_ = blockNode(node) - scope.pushNewScope(BlockScope(block_)) - val lhs = identifierNode(node, fieldName, fieldName, Defines.Any) - val rhs = identifierNode(node, parameter.name, parameter.name, Defines.Any) - val assignmentCall = callNode( - node, - s"${lhs.code} = ${rhs.code}", - Operators.assignment, - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - val assignmentAst = callAst(assignmentCall, Seq(Ast(lhs), Ast(rhs))) - scope.popScope() - blockAst(blockNode(node), List(assignmentAst)) - } - scope.popScope() - methodAst(method, Seq(Ast(parameter)), methodBody, methodReturnNode(node, Defines.Any)) + private def astForSetterMethod(node: FieldsDeclaration, fieldName: String): Seq[Ast] = { + val name = fieldName.drop(1) + "=" + val code = s"def $name (...)" + val assignment = SingleAssignment( + InstanceFieldIdentifier()(node.span.spanStart(fieldName)), + "=", + SimpleIdentifier()(node.span.spanStart("x")) + )(node.span.spanStart(s"$fieldName = x")) + val methodDecl = MethodDeclaration( + name, + MandatoryParameter("x")(node.span.spanStart("x")) :: Nil, + StatementList(assignment :: Nil)(node.span.spanStart(s"return $fieldName")) + )(node.span.spanStart(code)) + astForMethodDeclaration(methodDecl, useSurroundingTypeFullName = true) } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstSummaryVisitor.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstSummaryVisitor.scala index a1aceea6f561..7486bf43dac6 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstSummaryVisitor.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstSummaryVisitor.scala @@ -1,9 +1,8 @@ package io.joern.rubysrc2cpg.astcreation -import better.files.File -import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.StatementList +import flatgraph.DiffGraphApplier +import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{RubyExpression, StatementList} import io.joern.rubysrc2cpg.datastructures.{RubyField, RubyMethod, RubyProgramSummary, RubyStubbedType, RubyType} -import io.joern.rubysrc2cpg.parser.RubyNodeCreator import io.joern.rubysrc2cpg.passes.Defines import io.joern.x2cpg.layers.Base import io.joern.x2cpg.passes.base.{AstLinkerPass, FileCreationPass} @@ -12,7 +11,6 @@ import io.shiftleft.codepropertygraph.cpgloading.CpgLoader import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Local, Member, Method, TypeDecl} import io.shiftleft.semanticcpg.language.* -import overflowdb.{BatchedUpdate, Config} import java.io.File as JavaFile import java.util.regex.Matcher @@ -23,13 +21,13 @@ trait AstSummaryVisitor(implicit withSchemaValidation: ValidationMode) { this: A def summarize(asExternal: Boolean = false): RubyProgramSummary = { this.parseLevel = AstParseLevel.SIGNATURES + Using.resource(Cpg.empty) { cpg => // Build and store compilation unit AST - val rootNode = new RubyNodeCreator().visit(programCtx).asInstanceOf[StatementList] - val ast = astForRubyFile(rootNode) + val ast = astForRubyFile(rootNode) Ast.storeInDiffGraph(ast, diffGraph) - BatchedUpdate.applyDiff(cpg.graph, diffGraph) - CpgLoader.createIndexes(cpg) + DiffGraphApplier.applyDiff(cpg.graph, diffGraph) + // Link basic AST elements AstLinkerPass(cpg).createAndApply() // Summarize findings @@ -38,7 +36,7 @@ trait AstSummaryVisitor(implicit withSchemaValidation: ValidationMode) { this: A } def withSummary(newSummary: RubyProgramSummary): AstCreator = { - AstCreator(fileName, programCtx, projectRoot, newSummary) + AstCreator(fileName, projectRoot, newSummary, enableFileContents, fileContent, rootNode) } private def summarize(cpg: Cpg, asExternal: Boolean): RubyProgramSummary = { @@ -116,7 +114,7 @@ trait AstSummaryVisitor(implicit withSchemaValidation: ValidationMode) { this: A }.toSet // Map module types val typeEntries = namespace.method.collectFirst { - case m: Method if m.name == Defines.Program => + case m: Method if m.name == Defines.Main => val childrenTypes = m.astChildren.collectAll[TypeDecl].l val fullName = if childrenTypes.nonEmpty && asExternal then buildFullName(childrenTypes.head) else s"${m.fullName}" diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala index 01a57a1f6414..97cfeb272291 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala @@ -1,9 +1,12 @@ package io.joern.rubysrc2cpg.astcreation +import io.joern.rubysrc2cpg.passes.Defines.RubyOperators import io.joern.rubysrc2cpg.passes.{Defines, GlobalTypes} +import io.joern.x2cpg.Ast import io.shiftleft.codepropertygraph.generated.nodes.NewNode +import org.slf4j.LoggerFactory -import scala.annotation.tailrec +import java.util.Objects object RubyIntermediateAst { @@ -12,12 +15,15 @@ object RubyIntermediateAst { column: Option[Int], lineEnd: Option[Int], columnEnd: Option[Int], + offset: Option[(Int, Int)], text: String ) { - def spanStart(newText: String = ""): TextSpan = TextSpan(line, column, line, column, newText) + def spanStart(newText: String = ""): TextSpan = TextSpan(line, column, line, column, offset, newText) } - sealed class RubyNode(val span: TextSpan) { + /** Most-if-not-all constructs in Ruby evaluate to some value, so we name the base class `RubyExpression`. + */ + sealed class RubyExpression(val span: TextSpan) { def line: Option[Int] = span.line def column: Option[Int] = span.column @@ -26,19 +32,47 @@ object RubyIntermediateAst { def columnEnd: Option[Int] = span.columnEnd + def offset: Option[(Int, Int)] = span.offset + def text: String = span.text + + override def hashCode(): Int = Objects.hash(span) + + override def equals(obj: Any): Boolean = { + obj match { + case o: RubyExpression => o.span == span + case _ => false + } + } } - implicit class RubyNodeHelper(node: RubyNode) { + /** Ruby statements evaluate to some value (and thus are expressions), but also perform some operation, e.g., + * assignments, method definitions, etc. + */ + sealed trait RubyStatement extends RubyExpression + + implicit class RubyExpressionHelper(node: RubyExpression) { def asStatementList: StatementList = node match { case stmtList: StatementList => stmtList case _ => StatementList(List(node))(node.span) } } - final case class Unknown()(span: TextSpan) extends RubyNode(span) + final case class Unknown()(span: TextSpan) extends RubyExpression(span) - final case class StatementList(statements: List[RubyNode])(span: TextSpan) extends RubyNode(span) { + final case class StatementList(statements: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with RubyStatement { + override def text: String = statements.size match + case 0 | 1 => span.text + case _ => "(...)" + + def size: Int = statements.size + } + + final case class SingletonStatementList(statements: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with RubyStatement { override def text: String = statements.size match case 0 | 1 => span.text case _ => "(...)" @@ -48,106 +82,165 @@ object RubyIntermediateAst { sealed trait AllowedTypeDeclarationChild - sealed trait TypeDeclaration extends AllowedTypeDeclarationChild { - def name: RubyNode - def baseClass: Option[RubyNode] - def body: RubyNode - def bodyMemberCall: Option[MemberCall] + sealed trait TypeDeclaration extends AllowedTypeDeclarationChild with RubyStatement { + def name: RubyExpression + def baseClass: Option[RubyExpression] + def body: RubyExpression + def bodyMemberCall: Option[TypeDeclBodyCall] + } + + sealed trait NamespaceDeclaration extends RubyStatement { + def namespaceParts: Option[List[String]] } final case class ModuleDeclaration( - name: RubyNode, - body: RubyNode, - fields: List[RubyNode & RubyFieldIdentifier], - bodyMemberCall: Option[MemberCall] + name: RubyExpression, + body: RubyExpression, + fields: List[RubyExpression & RubyFieldIdentifier], + bodyMemberCall: Option[TypeDeclBodyCall], + namespaceParts: Option[List[String]] )(span: TextSpan) - extends RubyNode(span) - with TypeDeclaration { - def baseClass: Option[RubyNode] = None + extends RubyExpression(span) + with TypeDeclaration + with NamespaceDeclaration { + def baseClass: Option[RubyExpression] = None } final case class ClassDeclaration( - name: RubyNode, - baseClass: Option[RubyNode], - body: RubyNode, - fields: List[RubyNode & RubyFieldIdentifier], - bodyMemberCall: Option[MemberCall] + name: RubyExpression, + baseClass: Option[RubyExpression], + body: RubyExpression, + fields: List[RubyExpression & RubyFieldIdentifier], + bodyMemberCall: Option[TypeDeclBodyCall], + namespaceParts: Option[List[String]] )(span: TextSpan) - extends RubyNode(span) + extends RubyExpression(span) with TypeDeclaration + with NamespaceDeclaration - sealed trait AnonymousTypeDeclaration extends RubyNode with TypeDeclaration + sealed trait AnonymousTypeDeclaration extends RubyExpression with TypeDeclaration final case class AnonymousClassDeclaration( - name: RubyNode, - baseClass: Option[RubyNode], - body: RubyNode, - bodyMemberCall: Option[MemberCall] = None + name: RubyExpression, + baseClass: Option[RubyExpression], + body: RubyExpression, + bodyMemberCall: Option[TypeDeclBodyCall] = None )(span: TextSpan) - extends RubyNode(span) + extends RubyExpression(span) with AnonymousTypeDeclaration final case class SingletonClassDeclaration( - name: RubyNode, - baseClass: Option[RubyNode], - body: RubyNode, - bodyMemberCall: Option[MemberCall] = None + name: RubyExpression, + baseClass: Option[RubyExpression], + body: RubyExpression, + bodyMemberCall: Option[TypeDeclBodyCall] = None )(span: TextSpan) - extends RubyNode(span) + extends RubyExpression(span) with AnonymousTypeDeclaration - final case class FieldsDeclaration(fieldNames: List[RubyNode])(span: TextSpan) - extends RubyNode(span) + final case class FieldsDeclaration(fieldNames: List[RubyExpression], accessType: String)(span: TextSpan) + extends RubyExpression(span) with AllowedTypeDeclarationChild { def hasGetter: Boolean = text.startsWith("attr_reader") || text.startsWith("attr_accessor") - def hasSetter: Boolean = text.startsWith("attr_writer") || text.startsWith("attr_accessor") + + def isSplattingFieldDecl: Boolean = fieldNames.length == 1 && fieldNames.head.isInstanceOf[SplattingRubyNode] + } + + sealed trait ProcedureDeclaration extends RubyStatement { + def methodName: String + def parameters: List[RubyExpression] + def body: RubyExpression } - final case class MethodDeclaration(methodName: String, parameters: List[RubyNode], body: RubyNode)(span: TextSpan) - extends RubyNode(span) + final case class MethodDeclaration(methodName: String, parameters: List[RubyExpression], body: RubyExpression)( + span: TextSpan + ) extends RubyExpression(span) + with ProcedureDeclaration with AllowedTypeDeclarationChild final case class SingletonMethodDeclaration( - target: RubyNode, + target: RubyExpression, methodName: String, - parameters: List[RubyNode], - body: RubyNode + parameters: List[RubyExpression], + body: RubyExpression )(span: TextSpan) - extends RubyNode(span) + extends RubyExpression(span) + with ProcedureDeclaration with AllowedTypeDeclarationChild + final case class SingletonObjectMethodDeclaration( + methodName: String, + parameters: List[RubyExpression], + body: RubyExpression, + baseClass: RubyExpression + )(span: TextSpan) + extends RubyExpression(span) + with ProcedureDeclaration + sealed trait MethodParameter { def name: String } - final case class MandatoryParameter(name: String)(span: TextSpan) extends RubyNode(span) with MethodParameter + final case class MandatoryParameter(name: String)(span: TextSpan) extends RubyExpression(span) with MethodParameter { + def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier()(span) + } - final case class OptionalParameter(name: String, defaultExpression: RubyNode)(span: TextSpan) - extends RubyNode(span) + final case class OptionalParameter(name: String, defaultExpression: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with MethodParameter + + final case class GroupedParameter( + name: String, + tmpParam: RubyExpression, + multipleAssignment: GroupedParameterDesugaring + )(span: TextSpan) + extends RubyExpression(span) with MethodParameter sealed trait CollectionParameter extends MethodParameter - final case class ArrayParameter(name: String)(span: TextSpan) extends RubyNode(span) with CollectionParameter + final case class ArrayParameter(name: String)(span: TextSpan) extends RubyExpression(span) with CollectionParameter - final case class HashParameter(name: String)(span: TextSpan) extends RubyNode(span) with CollectionParameter + final case class HashParameter(name: String)(span: TextSpan) extends RubyExpression(span) with CollectionParameter - final case class ProcParameter(name: String)(span: TextSpan) extends RubyNode(span) with MethodParameter + final case class ProcParameter(name: String)(span: TextSpan) extends RubyExpression(span) with MethodParameter - final case class SingleAssignment(lhs: RubyNode, op: String, rhs: RubyNode)(span: TextSpan) extends RubyNode(span) + final case class SingleAssignment(lhs: RubyExpression, op: String, rhs: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with RubyStatement - final case class MultipleAssignment(assignments: List[SingleAssignment])(span: TextSpan) extends RubyNode(span) + trait MultipleAssignment extends RubyStatement { + def assignments: List[SingleAssignment] + } - final case class SplattingRubyNode(name: RubyNode)(span: TextSpan) extends RubyNode(span) + final case class OperatorAssignment(lhs: RubyExpression, op: String, rhs: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with RubyStatement + with ControlFlowStatement - final case class AttributeAssignment(target: RubyNode, op: String, attributeName: String, rhs: RubyNode)( - span: TextSpan - ) extends RubyNode(span) + final case class DefaultMultipleAssignment(assignments: List[SingleAssignment])(span: TextSpan) + extends RubyExpression(span) + with MultipleAssignment - /** Any structure that conditionally modifies the control flow of the program. + final case class GroupedParameterDesugaring(assignments: List[SingleAssignment])(span: TextSpan) + extends RubyExpression(span) + with MultipleAssignment + + final case class SplattingRubyNode(target: RubyExpression)(span: TextSpan) extends RubyExpression(span) + + final case class AttributeAssignment( + target: RubyExpression, + op: String, + attributeName: String, + assignmentOperator: String, + rhs: RubyExpression + )(span: TextSpan) + extends RubyExpression(span) + + /** Any structure that conditionally modifies the control flow of the program. These also behave as statements. */ - sealed trait ControlFlowExpression + sealed trait ControlFlowStatement extends RubyStatement /** A control structure's clause, which may contain an additional control structures. */ @@ -155,90 +248,124 @@ object RubyIntermediateAst { /** Any structure that is an Identifier, except self. e.g. `a`, `@a`, `@@a` */ - sealed trait RubyIdentifier + sealed trait RubyIdentifier extends RubyExpression { + override def toString: String = span.text + } /** Ruby Instance or Class Variable Identifiers: `@a`, `@@a` */ - sealed trait RubyFieldIdentifier extends RubyIdentifier + sealed trait RubyFieldIdentifier extends RubyIdentifier { + def toMemberAccess: MemberAccess = { + MemberAccess(SelfIdentifier()(span), ".", span.text)(span) + } + } sealed trait SingletonMethodIdentifier final case class RescueExpression( - body: RubyNode, + body: RubyExpression, rescueClauses: List[RescueClause], elseClause: Option[ElseClause], ensureClause: Option[EnsureClause] )(span: TextSpan) - extends RubyNode(span) - with ControlFlowExpression + extends RubyExpression(span) + with ControlFlowStatement final case class RescueClause( - exceptionClassList: Option[RubyNode], - variables: Option[RubyNode], - thenClause: RubyNode + exceptionClassList: Option[RubyExpression], + variables: Option[RubyExpression], + thenClause: RubyExpression )(span: TextSpan) - extends RubyNode(span) + extends RubyExpression(span) with ControlFlowClause - final case class EnsureClause(thenClause: RubyNode)(span: TextSpan) extends RubyNode(span) with ControlFlowClause + final case class EnsureClause(thenClause: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with ControlFlowClause - final case class WhileExpression(condition: RubyNode, body: RubyNode)(span: TextSpan) - extends RubyNode(span) - with ControlFlowExpression + final case class WhileExpression(condition: RubyExpression, body: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement - final case class DoWhileExpression(condition: RubyNode, body: RubyNode)(span: TextSpan) - extends RubyNode(span) - with ControlFlowExpression + final case class DoWhileExpression(condition: RubyExpression, body: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement - final case class UntilExpression(condition: RubyNode, body: RubyNode)(span: TextSpan) - extends RubyNode(span) - with ControlFlowExpression + final case class UntilExpression(condition: RubyExpression, body: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement final case class IfExpression( - condition: RubyNode, - thenClause: RubyNode, - elsifClauses: List[RubyNode], - elseClause: Option[RubyNode] + condition: RubyExpression, + thenClause: RubyExpression, + elsifClauses: List[RubyExpression] = Nil, + elseClause: Option[RubyExpression] = None )(span: TextSpan) - extends RubyNode(span) - with ControlFlowExpression + extends RubyExpression(span) + with ControlFlowStatement + with RubyStatement - final case class ElsIfClause(condition: RubyNode, thenClause: RubyNode)(span: TextSpan) - extends RubyNode(span) + final case class ElsIfClause(condition: RubyExpression, thenClause: RubyExpression)(span: TextSpan) + extends RubyExpression(span) with ControlFlowClause - final case class ElseClause(thenClause: RubyNode)(span: TextSpan) extends RubyNode(span) with ControlFlowClause + final case class ElseClause(thenClause: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with ControlFlowClause - final case class UnlessExpression(condition: RubyNode, trueBranch: RubyNode, falseBranch: Option[RubyNode])( - span: TextSpan - ) extends RubyNode(span) - with ControlFlowExpression + final case class UnlessExpression( + condition: RubyExpression, + trueBranch: RubyExpression, + falseBranch: Option[RubyExpression] + )(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement - final case class ForExpression(forVariable: RubyNode, iterableVariable: RubyNode, doBlock: RubyNode)(span: TextSpan) - extends RubyNode(span) - with ControlFlowExpression + final case class ForExpression( + forVariable: RubyExpression, + iterableVariable: RubyExpression, + doBlock: RubyExpression + )(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement final case class CaseExpression( - expression: Option[RubyNode], - whenClauses: List[RubyNode], - elseClause: Option[RubyNode] + expression: Option[RubyExpression], + matchClauses: List[RubyExpression], + elseClause: Option[RubyExpression] )(span: TextSpan) - extends RubyNode(span) - with ControlFlowExpression + extends RubyExpression(span) + with ControlFlowStatement final case class WhenClause( - matchExpressions: List[RubyNode], - matchSplatExpression: Option[RubyNode], - thenClause: RubyNode + matchExpressions: List[RubyExpression], + matchSplatExpression: Option[RubyExpression], + thenClause: RubyExpression )(span: TextSpan) - extends RubyNode(span) + extends RubyExpression(span) + with ControlFlowClause + + final case class InClause(pattern: RubyExpression, body: RubyExpression)(span: TextSpan) + extends RubyExpression(span) with ControlFlowClause - final case class ReturnExpression(expressions: List[RubyNode])(span: TextSpan) extends RubyNode(span) + final case class ArrayPattern(children: List[RubyExpression])(span: TextSpan) extends RubyExpression(span) + + final case class MatchVariable()(span: TextSpan) extends RubyExpression(span) { + def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier()(span) + } + + final case class NextExpression()(span: TextSpan) extends RubyExpression(span) with ControlFlowStatement + + final case class BreakExpression()(span: TextSpan) extends RubyExpression(span) with ControlFlowStatement + + final case class ReturnExpression(expressions: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with RubyStatement /** Represents an unqualified identifier e.g. `X`, `x`, `@@x`, `$x`, `$<`, etc. */ final case class SimpleIdentifier(typeFullName: Option[String] = None)(span: TextSpan) - extends RubyNode(span) + extends RubyExpression(span) with RubyIdentifier with SingletonMethodIdentifier { override def toString: String = s"SimpleIdentifier(${span.text}, $typeFullName)" @@ -246,18 +373,20 @@ object RubyIntermediateAst { /** Represents a type reference successfully determined, e.g. module A; end; A */ - final case class TypeIdentifier(typeFullName: String)(span: TextSpan) extends RubyNode(span) with RubyIdentifier { - def isBuiltin: Boolean = typeFullName.startsWith(GlobalTypes.builtinPrefix) + final case class TypeIdentifier(typeFullName: String)(span: TextSpan) + extends RubyExpression(span) + with RubyIdentifier { + def isBuiltin: Boolean = typeFullName.startsWith(GlobalTypes.corePrefix) override def toString: String = s"TypeIdentifier(${span.text}, $typeFullName)" } /** Represents a InstanceFieldIdentifier e.g `@x` */ - final case class InstanceFieldIdentifier()(span: TextSpan) extends RubyNode(span) with RubyFieldIdentifier + final case class InstanceFieldIdentifier()(span: TextSpan) extends RubyExpression(span) with RubyFieldIdentifier /** Represents a ClassFieldIdentifier e.g `@@x` */ - final case class ClassFieldIdentifier()(span: TextSpan) extends RubyNode(span) with RubyFieldIdentifier + final case class ClassFieldIdentifier()(span: TextSpan) extends RubyExpression(span) with RubyFieldIdentifier - final case class SelfIdentifier()(span: TextSpan) extends RubyNode(span) with SingletonMethodIdentifier + final case class SelfIdentifier()(span: TextSpan) extends RubyExpression(span) with SingletonMethodIdentifier /** Represents some kind of literal expression. */ @@ -266,7 +395,7 @@ object RubyIntermediateAst { } /** Represents a non-interpolated literal. */ - final case class StaticLiteral(typeFullName: String)(span: TextSpan) extends RubyNode(span) with LiteralExpr { + final case class StaticLiteral(typeFullName: String)(span: TextSpan) extends RubyExpression(span) with LiteralExpr { def isSymbol: Boolean = text.startsWith(":") def isString: Boolean = text.startsWith("\"") || text.startsWith("'") @@ -280,19 +409,43 @@ object RubyIntermediateAst { case s => s } } + } - final case class DynamicLiteral(typeFullName: String, expressions: List[RubyNode])(span: TextSpan) - extends RubyNode(span) + object StaticLiteral { + + private val logger = LoggerFactory.getLogger(getClass) + + def unapply(literal: StaticLiteral): Option[String] = { + val typeName = literal.typeFullName.stripPrefix(s"${GlobalTypes.corePrefix}.") + Some(typeName).filter(GlobalTypes.bundledClasses.contains) match { + case None => + logger.warn( + s"Unapply called on static literal with type not contained within known bundled classes: ${literal.typeFullName}" + ) + Some(literal.typeFullName) + case x => x + } + + } + } + + final case class DynamicLiteral(typeFullName: String, expressions: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) with LiteralExpr - final case class RangeExpression(lowerBound: RubyNode, upperBound: RubyNode, rangeOperator: RangeOperator)( - span: TextSpan - ) extends RubyNode(span) + final case class RangeExpression( + lowerBound: RubyExpression, + upperBound: RubyExpression, + rangeOperator: RangeOperator + )(span: TextSpan) + extends RubyExpression(span) - final case class RangeOperator(exclusive: Boolean)(span: TextSpan) extends RubyNode(span) + final case class RangeOperator(exclusive: Boolean)(span: TextSpan) extends RubyExpression(span) - final case class ArrayLiteral(elements: List[RubyNode])(span: TextSpan) extends RubyNode(span) with LiteralExpr { + final case class ArrayLiteral(elements: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with LiteralExpr { def isSymbolArray: Boolean = text.take(2).toLowerCase.startsWith("%i") def isStringArray: Boolean = text.take(2).toLowerCase.startsWith("%w") @@ -301,51 +454,98 @@ object RubyIntermediateAst { def isStatic: Boolean = !isDynamic - def typeFullName: String = Defines.getBuiltInType(Defines.Array) + def typeFullName: String = Defines.prefixAsCoreType(Defines.Array) } - final case class HashLiteral(elements: List[RubyNode])(span: TextSpan) extends RubyNode(span) with LiteralExpr { - def typeFullName: String = Defines.getBuiltInType(Defines.Hash) + sealed trait HashLike extends RubyExpression with LiteralExpr { + def elements: List[RubyExpression] + def typeFullName: String = Defines.prefixAsCoreType(Defines.Hash) } - final case class Association(key: RubyNode, value: RubyNode)(span: TextSpan) extends RubyNode(span) + final case class HashLiteral(elements: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with HashLike + + final case class Association(key: RubyExpression, value: RubyExpression)(span: TextSpan) extends RubyExpression(span) + + final case class AssociationList(elements: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with HashLike /** Represents a call. */ - sealed trait RubyCall { - def target: RubyNode - def arguments: List[RubyNode] + sealed trait RubyCall extends RubyExpression { + def target: RubyExpression + def arguments: List[RubyExpression] + def withBlock(block: Block): RubyCallWithBlock[?] = SimpleCallWithBlock(target, arguments, block)(span) } /** Represents traditional calls, e.g. `foo`, `foo x, y`, `foo(x,y)` */ - final case class SimpleCall(target: RubyNode, arguments: List[RubyNode])(span: TextSpan) - extends RubyNode(span) + final case class SimpleCall(target: RubyExpression, arguments: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) with RubyCall final case class RequireCall( - target: RubyNode, - argument: RubyNode, + target: RubyExpression, + argument: RubyExpression, isRelative: Boolean = false, isWildCard: Boolean = false )(span: TextSpan) - extends RubyNode(span) + extends RubyExpression(span) with RubyCall { - def arguments: List[RubyNode] = List(argument) - def asSimpleCall: SimpleCall = SimpleCall(target, arguments)(span) + def arguments: List[RubyExpression] = List(argument) + def asSimpleCall: SimpleCall = SimpleCall(target, arguments)(span) } - final case class IncludeCall(target: RubyNode, argument: RubyNode)(span: TextSpan) - extends RubyNode(span) + final case class IncludeCall(target: RubyExpression, argument: RubyExpression)(span: TextSpan) + extends RubyExpression(span) with RubyCall { - def arguments: List[RubyNode] = List(argument) - def asSimpleCall: SimpleCall = SimpleCall(target, arguments)(span) + def arguments: List[RubyExpression] = List(argument) + def asSimpleCall: SimpleCall = SimpleCall(target, arguments)(span) + } + + final case class RaiseCall(target: RubyExpression, arguments: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with RubyCall + + sealed trait AccessModifier extends AllowedTypeDeclarationChild { + def toSimpleIdentifier: SimpleIdentifier + } + + sealed trait MethodAccessModifier extends AllowedTypeDeclarationChild { + def toSimpleIdentifier: SimpleIdentifier + def method: RubyExpression + } + + final case class PublicModifier()(span: TextSpan) extends RubyExpression(span) with AccessModifier { + override def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span) + } + + final case class PrivateModifier()(span: TextSpan) extends RubyExpression(span) with AccessModifier { + override def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span) + } + + final case class ProtectedModifier()(span: TextSpan) extends RubyExpression(span) with AccessModifier { + override def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span) + } + + final case class PrivateMethodModifier(method: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with MethodAccessModifier { + override def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span.spanStart("private_class_method")) + } + + final case class PublicMethodModifier(method: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with MethodAccessModifier { + override def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span.spanStart("public_class_method")) } /** Represents standalone `proc { ... }` or `lambda { ... }` expressions */ - final case class ProcOrLambdaExpr(block: Block)(span: TextSpan) extends RubyNode(span) + final case class ProcOrLambdaExpr(block: Block)(span: TextSpan) extends RubyExpression(span) - final case class YieldExpr(arguments: List[RubyNode])(span: TextSpan) extends RubyNode(span) + final case class YieldExpr(arguments: List[RubyExpression])(span: TextSpan) extends RubyExpression(span) /** Represents a call with a block argument. */ @@ -353,78 +553,112 @@ object RubyIntermediateAst { def block: Block - def withoutBlock: RubyNode & C + def withoutBlock: RubyExpression & C } - final case class SimpleCallWithBlock(target: RubyNode, arguments: List[RubyNode], block: Block)(span: TextSpan) - extends RubyNode(span) + final case class SimpleCallWithBlock(target: RubyExpression, arguments: List[RubyExpression], block: Block)( + span: TextSpan + ) extends RubyExpression(span) with RubyCallWithBlock[SimpleCall] { def withoutBlock: SimpleCall = SimpleCall(target, arguments)(span) } /** Represents member calls, e.g. `x.y(z,w)` */ - final case class MemberCall(target: RubyNode, op: String, methodName: String, arguments: List[RubyNode])( + final case class MemberCall(target: RubyExpression, op: String, methodName: String, arguments: List[RubyExpression])( span: TextSpan - ) extends RubyNode(span) - with RubyCall + ) extends RubyExpression(span) + with RubyCall { + + def isRegexMatch: Boolean = methodName == RubyOperators.regexpMatch + + override def withBlock(block: Block): RubyCallWithBlock[?] = + MemberCallWithBlock(target, op, methodName, arguments, block)(span) + } + + /** Special class for `` calls of type decls. + */ + final case class TypeDeclBodyCall(target: RubyExpression, typeName: String)(span: TextSpan) + extends RubyExpression(span) + with RubyCall { + + def toMemberCall: MemberCall = MemberCall(target, op, Defines.TypeDeclBody, arguments)(span) + + def arguments: List[RubyExpression] = Nil + + def op: String = "::" + } final case class MemberCallWithBlock( - target: RubyNode, + target: RubyExpression, op: String, methodName: String, - arguments: List[RubyNode], + arguments: List[RubyExpression], block: Block )(span: TextSpan) - extends RubyNode(span) + extends RubyExpression(span) with RubyCallWithBlock[MemberCall] { def withoutBlock: MemberCall = MemberCall(target, op, methodName, arguments)(span) } /** Represents index accesses, e.g. `x[0]`, `self.x.y[1, 2]` */ - final case class IndexAccess(target: RubyNode, indices: List[RubyNode])(span: TextSpan) extends RubyNode(span) + final case class IndexAccess(target: RubyExpression, indices: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) - final case class MemberAccess(target: RubyNode, op: String, memberName: String)(span: TextSpan) - extends RubyNode(span) { - override def toString: String = s"${target.text}.$memberName" + final case class MemberAccess(target: RubyExpression, op: String, memberName: String)(span: TextSpan) + extends RubyExpression(span) { + override def toString: String = s"${target.text}${op}$memberName" } /** A Ruby node that instantiates objects. */ sealed trait ObjectInstantiation extends RubyCall - final case class SimpleObjectInstantiation(target: RubyNode, arguments: List[RubyNode])(span: TextSpan) - extends RubyNode(span) - with ObjectInstantiation + final case class SimpleObjectInstantiation(target: RubyExpression, arguments: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with ObjectInstantiation { + override def withBlock(block: Block): RubyCallWithBlock[SimpleObjectInstantiation] = + ObjectInstantiationWithBlock(target, arguments, block)(span) + } - final case class ObjectInstantiationWithBlock(target: RubyNode, arguments: List[RubyNode], block: Block)( + final case class ObjectInstantiationWithBlock(target: RubyExpression, arguments: List[RubyExpression], block: Block)( span: TextSpan - ) extends RubyNode(span) + ) extends RubyExpression(span) with ObjectInstantiation with RubyCallWithBlock[SimpleObjectInstantiation] { def withoutBlock: SimpleObjectInstantiation = SimpleObjectInstantiation(target, arguments)(span) } /** Represents a `do` or `{ .. }` (braces) block. */ - final case class Block(parameters: List[RubyNode], body: RubyNode)(span: TextSpan) extends RubyNode(span) { + final case class Block(parameters: List[RubyExpression], body: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with RubyStatement { - def toMethodDeclaration(name: String, parameters: Option[List[RubyNode]]): MethodDeclaration = parameters match { - case Some(givenParameters) => MethodDeclaration(name, givenParameters, body)(span) - case None => MethodDeclaration(name, this.parameters, body)(span) - } + def toStatementList: StatementList = StatementList(body :: Nil)(span) + def toMethodDeclaration(name: String, parameters: Option[List[RubyExpression]]): MethodDeclaration = + parameters match { + case Some(givenParameters) => MethodDeclaration(name, givenParameters, body)(span) + case None => MethodDeclaration(name, this.parameters, body)(span) + } } /** A dummy class for wrapping around `NewNode` and allowing it to integrate with RubyNode classes. */ - final case class DummyNode(node: NewNode)(span: TextSpan) extends RubyNode(span) + final case class DummyNode(node: NewNode)(span: TextSpan) extends RubyExpression(span) - final case class UnaryExpression(op: String, expression: RubyNode)(span: TextSpan) extends RubyNode(span) + /** A dummy class for wrapping around `Ast` and allowing it to integrate with RubyNode classes. + */ + final case class DummyAst(ast: Ast)(span: TextSpan) extends RubyExpression(span) + + final case class UnaryExpression(op: String, expression: RubyExpression)(span: TextSpan) extends RubyExpression(span) - final case class BinaryExpression(lhs: RubyNode, op: String, rhs: RubyNode)(span: TextSpan) extends RubyNode(span) + final case class BinaryExpression(lhs: RubyExpression, op: String, rhs: RubyExpression)(span: TextSpan) + extends RubyExpression(span) - final case class HereDocNode(content: String)(span: TextSpan) extends RubyNode(span) + final case class HereDocNode(content: String)(span: TextSpan) extends RubyExpression(span) - final case class AliasStatement(oldName: String, newName: String)(span: TextSpan) extends RubyNode(span) + final case class AliasStatement(oldName: String, newName: String)(span: TextSpan) + extends RubyExpression(span) + with AllowedTypeDeclarationChild - final case class BreakStatement()(span: TextSpan) extends RubyNode(span) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/RubyScope.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/RubyScope.scala index 2b98bce044aa..745f5148d668 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/RubyScope.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/RubyScope.scala @@ -1,17 +1,17 @@ package io.joern.rubysrc2cpg.datastructures import better.files.File -import io.joern.rubysrc2cpg.passes.GlobalTypes -import io.joern.rubysrc2cpg.passes.GlobalTypes.builtinPrefix +import io.joern.rubysrc2cpg.passes.{GlobalTypes, Defines as RubyDefines} import io.joern.x2cpg.Defines -import io.joern.rubysrc2cpg.passes.Defines as RDefines import io.joern.x2cpg.datastructures.* import io.shiftleft.codepropertygraph.generated.NodeTypes import io.shiftleft.codepropertygraph.generated.nodes.{DeclarationNew, NewLocal, NewMethodParameterIn} +import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import java.io.File as JFile import scala.collection.mutable import scala.reflect.ClassTag +import scala.collection.mutable import scala.util.Try class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String]) @@ -26,17 +26,30 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String]) mutable.Set(RubyType(GlobalTypes.kernelPrefix, builtinMethods, List.empty)) // Add some built-in methods that are significant - // TODO: Perhaps create an offline pre-built list of methods typesInScope.addAll( Seq( RubyType( - s"$builtinPrefix.Array", - List(RubyMethod("[]", List.empty, s"$builtinPrefix.Array", Option(s"$builtinPrefix.Array"))), + RubyDefines.prefixAsCoreType(RubyDefines.Array), + List( + RubyMethod( + "[]", + List.empty, + RubyDefines.prefixAsCoreType(RubyDefines.Array), + Option(RubyDefines.prefixAsCoreType(RubyDefines.Array)) + ) + ), List.empty ), RubyType( - s"$builtinPrefix.Hash", - List(RubyMethod("[]", List.empty, s"$builtinPrefix.Hash", Option(s"$builtinPrefix.Hash"))), + RubyDefines.prefixAsCoreType(RubyDefines.Hash), + List( + RubyMethod( + "[]", + List.empty, + RubyDefines.prefixAsCoreType(RubyDefines.Hash), + Option(RubyDefines.prefixAsCoreType(RubyDefines.Hash)) + ) + ), List.empty ) ) @@ -47,7 +60,8 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String]) /** @return * using the stack, will initialize a new module scope object. */ - def newProgramScope: Option[ProgramScope] = surroundingScopeFullName.map(ProgramScope.apply) + def newProgramScope: Option[ProgramScope] = + surroundingScopeFullName.map(_.stripSuffix(NamespaceTraversal.globalNamespaceName)).map(ProgramScope.apply) /** @return * true if the top of the stack is the program/module. @@ -124,6 +138,13 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String]) } } + def lookupVariableInOuterScope(identifier: String): List[DeclarationNew] = { + stack.drop(1).collect { + case scopeElement if scopeElement.variables.contains(identifier) => + scopeElement.variables(identifier) + } + } + def addRequire( projectRoot: String, currentFilePath: String, @@ -222,17 +243,40 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String]) def useProcParam: Option[String] = updateSurrounding { case ScopeElement(MethodScope(fullName, param, _), variables) => (ScopeElement(MethodScope(fullName, param, true), variables), param.fold(x => x, x => x)) + case ScopeElement(ConstructorScope(fullName, param, _), variables) => + (ScopeElement(ConstructorScope(fullName, param, true), variables), param.fold(x => x, x => x)) } - /** Get the name of the implicit or explict proc param */ - def anonProcParam: Option[String] = stack.collectFirst { case ScopeElement(MethodScope(_, Left(param), true), _) => - param + /** Get the name of the implicit or explicit proc param */ + def anonProcParam: Option[String] = stack.collectFirst { + case ScopeElement(x: MethodLikeScope, _) if x.procParam.isLeft => + x.procParam match { + case Left(param) => param + case Right(param) => + param // this is just so that we don't get a pattern match warning, but should never be triggered + } } - /** Set the name of explict proc param */ - def setProcParam(param: String): Unit = updateSurrounding { + /** Set the name of explicit proc param */ + def setProcParam(param: String, paramNode: NewMethodParameterIn): Unit = updateSurrounding { case ScopeElement(MethodScope(fullName, _, _), variables) => - (ScopeElement(MethodScope(fullName, Right(param)), variables), ()) + (ScopeElement(MethodScope(fullName, Right(param), true), variables ++ Map(paramNode.name -> paramNode)), ()) + case ScopeElement(ConstructorScope(fullName, _, _), variables) => + (ScopeElement(ConstructorScope(fullName, Right(param), true), variables ++ Map(paramNode.name -> paramNode)), ()) + } + + /** If a proc param is used, provides the node to add to the AST. + */ + def procParamName: Option[NewMethodParameterIn] = { + stack + .collectFirst { + case ScopeElement(x: MethodLikeScope, _) if x.hasYield => + x.procParam match { + case Left(param) => param + case Right(param) => param + } + } + .flatMap(lookupVariable(_).collect { case p: NewMethodParameterIn => p }) } def surroundingTypeFullName: Option[String] = stack.collectFirst { case ScopeElement(x: TypeLikeScope, _) => @@ -330,9 +374,9 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String]) .orElse { super.tryResolveTypeReference(normalizedTypeName) match { case None if GlobalTypes.kernelFunctions.contains(normalizedTypeName) => - Option(RubyType(s"${GlobalTypes.kernelPrefix}.$normalizedTypeName", List.empty, List.empty)) + Option(RubyType(RubyDefines.prefixAsKernelDefined(normalizedTypeName), List.empty, List.empty)) case None if GlobalTypes.bundledClasses.contains(normalizedTypeName) => - Option(RubyType(s"<${GlobalTypes.builtinPrefix}.$normalizedTypeName>", List.empty, List.empty)) + Option(RubyType(RubyDefines.prefixAsCoreType(normalizedTypeName), List.empty, List.empty)) case None => None case x => x @@ -340,4 +384,22 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String]) } } + /** @param identifier + * the name of the variable. + * @return + * the full name of the variable's scope, if available. + */ + def variableScopeFullName(identifier: String): Option[String] = { + stack + .collectFirst { + case scopeElement if scopeElement.variables.contains(identifier) => + scopeElement + } + .map { + case ScopeElement(x: NamespaceLikeScope, _) => x.fullName + case ScopeElement(x: TypeLikeScope, _) => x.fullName + case ScopeElement(x: MethodLikeScope, _) => x.fullName + } + } + } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/ScopeElement.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/ScopeElement.scala index f7661770e69c..402d7c76b93f 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/ScopeElement.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/ScopeElement.scala @@ -1,6 +1,6 @@ package io.joern.rubysrc2cpg.datastructures -import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{RubyFieldIdentifier, RubyNode} +import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{RubyFieldIdentifier, RubyExpression} import io.joern.rubysrc2cpg.passes.Defines import io.joern.x2cpg.datastructures.{NamespaceLikeScope, TypedScopeElement} import io.shiftleft.codepropertygraph.generated.nodes.NewBlock @@ -16,7 +16,7 @@ case class FieldDecl( typeFullName: String, isStatic: Boolean, isInitialized: Boolean, - node: RubyNode & RubyFieldIdentifier + node: RubyExpression & RubyFieldIdentifier ) extends TypedScopeElement /** A type-like scope with a full name. @@ -35,7 +35,7 @@ trait TypeLikeScope extends TypedScopeElement { * the relative file name. */ case class ProgramScope(fileName: String) extends TypeLikeScope { - override def fullName: String = s"$fileName:${Defines.Program}" + override def fullName: String = s"$fileName${Defines.Main}" } /** A Ruby module/abstract class. @@ -55,12 +55,15 @@ case class TypeScope(fullName: String, fields: List[FieldDecl]) extends TypeLike */ trait MethodLikeScope extends TypedScopeElement { def fullName: String + def procParam: Either[String, String] + def hasYield: Boolean } case class MethodScope(fullName: String, procParam: Either[String, String], hasYield: Boolean = false) extends MethodLikeScope -case class ConstructorScope(fullName: String) extends MethodLikeScope +case class ConstructorScope(fullName: String, procParam: Either[String, String], hasYield: Boolean = false) + extends MethodLikeScope /** Represents scope objects that map to a block node. */ diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/ParseInternalStructures.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/ParseInternalStructures.scala deleted file mode 100644 index 6f810a9dd14e..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/ParseInternalStructures.scala +++ /dev/null @@ -1,156 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated - -import io.joern.rubysrc2cpg.RubySrc2Cpg -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.* -import io.joern.rubysrc2cpg.deprecated.utils.PackageTable -import org.antlr.v4.runtime.ParserRuleContext -import org.antlr.v4.runtime.misc.Interval -import org.slf4j.LoggerFactory - -import java.io.File as JFile -import scala.collection.mutable -import scala.jdk.CollectionConverters.* -import scala.util.{Failure, Try} - -class ParseInternalStructures( - parsedFiles: List[(String, DeprecatedRubyParser.ProgramContext)], - projectRoot: Option[String] = None -) { - - private val logger = LoggerFactory.getLogger(getClass) - - def populatePackageTable(): Unit = { - parsedFiles.foreach { case (fileName, programCtx) => - Try { - val relativeFilename: String = - projectRoot.map(fileName.stripPrefix).map(_.stripPrefix(JFile.separator)).getOrElse(fileName) - implicit val classStack: mutable.Stack[String] = mutable.Stack[String]() - parseForStructures(relativeFilename, programCtx) - } match { - case Failure(exception) => - logger.warn(s"Exception encountered while scanning for internal structures in file '$fileName'", exception) - case _ => // do nothing - } - } - } - - private def parseForStructures(relativeFilename: String, programCtx: ProgramContext)(implicit - classStack: mutable.Stack[String] - ): Unit = { - val name = ":program" - val fullName = s"$relativeFilename:$name" - classStack.push(fullName) - if ( - programCtx.compoundStatement() != null && - programCtx.compoundStatement().statements() != null - ) { - programCtx.compoundStatement().statements().statement().asScala.foreach(parseStatement) - } - classStack.pop() - } - - private def parseStatement(ctx: StatementContext)(implicit classStack: mutable.Stack[String]): Unit = ctx match { - case ctx: ExpressionOrCommandStatementContext => parseExpressionOrCommand(ctx.expressionOrCommand()) - case _ => - } - - private def parseExpressionOrCommand( - ctx: ExpressionOrCommandContext - )(implicit classStack: mutable.Stack[String]): Unit = ctx match { - case ctx: ExpressionExpressionOrCommandContext => parseExpressionContext(ctx.expression()) - case _ => - } - - private def parseExpressionContext(ctx: ExpressionContext)(implicit classStack: mutable.Stack[String]): Unit = - ctx match { - case ctx: PrimaryExpressionContext => parsePrimaryContext(ctx.primary()) - case _ => - } - - private def parsePrimaryContext(ctx: PrimaryContext)(implicit classStack: mutable.Stack[String]): Unit = ctx match { - case ctx: MethodDefinitionPrimaryContext => parseMethodDefinitionContext(ctx.methodDefinition()) - case ctx: ModuleDefinitionPrimaryContext => parseModuleDefinitionContext(ctx.moduleDefinition()) - case ctx: ClassDefinitionPrimaryContext => parseClassDefinition(ctx.classDefinition()) - case _ => - } - - private def parseModuleDefinitionContext( - moduleDefinitionContext: ModuleDefinitionContext - )(implicit classStack: mutable.Stack[String]): Unit = { - val className = moduleDefinitionContext.classOrModuleReference().CONSTANT_IDENTIFIER().getText - classStack.push(className) - parseClassBody(moduleDefinitionContext.bodyStatement()) - } - - private def parseClassDefinition( - classDef: ClassDefinitionContext - )(implicit classStack: mutable.Stack[String]): Unit = { - Option(classDef).foreach { ctx => - Option(ctx.classOrModuleReference()).map(_.CONSTANT_IDENTIFIER().getText).foreach { className => - classStack.push(className) - parseClassBody(ctx.bodyStatement()) - } - } - } - - private def parseClassBody(ctx: BodyStatementContext)(implicit classStack: mutable.Stack[String]): Unit = { - Option(ctx).map(_.compoundStatement()).map(_.statements()).foreach(_.statement().asScala.foreach(parseStatement)) - } - - private def parseMethodDefinitionContext( - ctx: MethodDefinitionContext - )(implicit classStack: mutable.Stack[String]): Unit = { - val maybeMethodName = Option(ctx.methodNamePart()) match - case Some(ctxMethodNamePart) => - readMethodNamePart(ctxMethodNamePart) - case None => - readMethodIdentifier(ctx.methodIdentifier()) - - maybeMethodName.foreach { methodName => - val classType = if (classStack.isEmpty) "Standalone" else classStack.top - val classPath = classStack.reverse.toList.mkString(".") - RubySrc2Cpg.packageTableInfo.addPackageMethod(PackageTable.InternalModule, methodName, classPath, classType) - } - } - - private def readMethodNamePart(ctx: MethodNamePartContext): Option[String] = { - ctx match - case context: SimpleMethodNamePartContext => - Option(context.definedMethodName().methodName()) match - case Some(methodNameCtx) => Try(methodNameCtx.methodIdentifier().getText).toOption - case None => None - case context: SingletonMethodNamePartContext => - Option(context.definedMethodName().methodName()) match - case Some(methodNameCtx) => Try(methodNameCtx.methodIdentifier().getText).toOption - case None => None - case _ => None - } - - private def readMethodIdentifier(ctx: MethodIdentifierContext): Option[String] = { - if (ctx.methodOnlyIdentifier() != null) { - readMethodOnlyIdentifier(ctx.methodOnlyIdentifier()) - } else if (ctx.LOCAL_VARIABLE_IDENTIFIER() != null) { - Option(ctx.LOCAL_VARIABLE_IDENTIFIER().getSymbol.getText) - } else { - None - } - } - - private def readMethodOnlyIdentifier(ctx: MethodOnlyIdentifierContext): Option[String] = { - if (ctx.LOCAL_VARIABLE_IDENTIFIER() != null || ctx.CONSTANT_IDENTIFIER() != null) { - text(ctx) - } else { - None - } - } - - private def text(ctx: ParserRuleContext): Option[String] = Try { - val a = ctx.getStart.getStartIndex - val b = ctx.getStop.getStopIndex - val intv = new Interval(a, b) - val input = ctx.getStart.getInputStream - input.getText(intv) - }.toOption - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AntlrParser.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AntlrParser.scala deleted file mode 100644 index 89ce67e1f3e8..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AntlrParser.scala +++ /dev/null @@ -1,71 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import io.joern.rubysrc2cpg.deprecated.parser.{ - DeprecatedRubyLexer, - DeprecatedRubyLexerPostProcessor, - DeprecatedRubyParser -} -import org.antlr.v4.runtime.* -import org.antlr.v4.runtime.atn.ATN -import org.antlr.v4.runtime.dfa.DFA -import org.slf4j.LoggerFactory - -import scala.util.Try - -/** A consumable wrapper for the RubyParser class used to parse the given file and be disposed thereafter. - * @param filename - * the file path to the file to be parsed. - */ -class AntlrParser(filename: String) { - - private val charStream = CharStreams.fromFileName(filename) - private val lexer = new DeprecatedRubyLexer(charStream) - private val tokenStream = new CommonTokenStream(DeprecatedRubyLexerPostProcessor(lexer)) - val parser: DeprecatedRubyParser = new DeprecatedRubyParser(tokenStream) - - def parse(): Try[DeprecatedRubyParser.ProgramContext] = Try(parser.program()) -} - -/** A re-usable parser object that clears the ANTLR DFA-cache if it determines that the memory usage is becoming large. - * Once this parser is closed, the whole cache is evicted. - * - * This is done in this way since clearing the cache after each file is inefficient, since the cache must be re-built - * every time, but the cache can become unnecessarily large at times. The cache also does not evict itself at the end - * of parsing. - * - * @param clearLimit - * the percentage of used heap to clear the DFA-cache on. - */ -class ResourceManagedParser(clearLimit: Double) extends AutoCloseable { - - private val logger = LoggerFactory.getLogger(getClass) - private val runtime = Runtime.getRuntime - private var maybeDecisionToDFA: Option[Array[DFA]] = None - private var maybeAtn: Option[ATN] = None - - def parse(filename: String): Try[DeprecatedRubyParser.ProgramContext] = { - val antlrParser = AntlrParser(filename) - val interp = antlrParser.parser.getInterpreter - // We need to grab a live instance in order to get the static variables as they are protected from static access - maybeDecisionToDFA = Option(interp.decisionToDFA) - maybeAtn = Option(interp.atn) - val usedMemory = runtime.freeMemory.toDouble / runtime.totalMemory.toDouble - if (usedMemory >= clearLimit) { - logger.info(s"Runtime memory consumption at $usedMemory, clearing ANTLR DFA cache") - clearDFA() - } - antlrParser.parse() - } - - /** Clears the shared DFA cache. - */ - private def clearDFA(): Unit = if (maybeDecisionToDFA.isDefined && maybeAtn.isDefined) { - val decisionToDFA = maybeDecisionToDFA.get - val atn = maybeAtn.get - for (d <- decisionToDFA.indices) { - decisionToDFA(d) = new DFA(atn.getDecisionState(d), d) - } - } - - override def close(): Unit = clearDFA() -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstCreator.scala deleted file mode 100644 index 347319eca95f..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstCreator.scala +++ /dev/null @@ -1,760 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.* -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.rubysrc2cpg.deprecated.utils.PackageContext -import io.joern.x2cpg.Ast.storeInDiffGraph -import io.joern.x2cpg.Defines.DynamicCallUnknownFullName -import io.joern.x2cpg.X2Cpg.stripQuotes -import io.joern.x2cpg.datastructures.Global -import io.joern.x2cpg.utils.NodeBuilders.newModifierNode -import io.joern.x2cpg.{Ast, AstCreatorBase, AstNodeBuilder, ValidationMode, Defines as XDefines} -import io.shiftleft.codepropertygraph.generated.* -import io.shiftleft.codepropertygraph.generated.nodes.* -import org.antlr.v4.runtime.ParserRuleContext -import org.slf4j.LoggerFactory -import overflowdb.BatchedUpdate - -import java.io.File as JFile -import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} -import scala.collection.immutable.Seq -import scala.collection.mutable -import scala.collection.mutable.ListBuffer -import scala.jdk.CollectionConverters.* -import scala.util.{Failure, Success} - -class AstCreator( - filename: String, - programCtx: DeprecatedRubyParser.ProgramContext, - protected val packageContext: PackageContext, - projectRoot: Option[String] = None -)(implicit withSchemaValidation: ValidationMode) - extends AstCreatorBase(filename) - with AstNodeBuilder[ParserRuleContext, AstCreator] - with AstForPrimitivesCreator - with AstForStatementsCreator(filename) - with AstForFunctionsCreator - with AstForExpressionsCreator - with AstForDeclarationsCreator - with AstForTypesCreator - with AstForControlStructuresCreator - with AstCreatorHelper - with AstForHereDocsCreator { - - protected val scope: RubyScope = new RubyScope() - - private val logger = LoggerFactory.getLogger(this.getClass) - - protected val classStack: mutable.Stack[String] = mutable.Stack[String]() - - protected val packageStack: mutable.Stack[String] = mutable.Stack[String]() - - protected val pathSep = "." - - protected val relativeFilename: String = - projectRoot.map(filename.stripPrefix).map(_.stripPrefix(JFile.separator)).getOrElse(filename) - - // The below are for adding implicit return nodes to methods - - // This is true if the last statement of a method is being processed. The last statement could be a if-else as well - protected val processingLastMethodStatement: AtomicBoolean = AtomicBoolean(false) - // a monotonically increasing block id unique within this file - protected val blockIdCounter: AtomicInteger = AtomicInteger(1) - // block id of the block currently being processed - protected val currentBlockId: AtomicInteger = AtomicInteger(0) - /* - * This is a hash of parent block id ---> child block id. If there are multiple children, any one child can be present. - * The value of this entry for a block is read AFTER its last statement has been processed. Absence of the the block - * in this hash implies this is a leaf block. - */ - protected val blockChildHash: mutable.Map[Int, Int] = mutable.HashMap[Int, Int]() - - private val builtInCallNames = mutable.HashSet[String]() - // Hashmap to store used variable names, to avoid duplicates in case of un-named variables - protected val usedVariableNames = mutable.HashMap.empty[String, Int] - - override def createAst(): BatchedUpdate.DiffGraphBuilder = createAstForProgramCtx(programCtx) - - private def createAstForProgramCtx(programCtx: DeprecatedRubyParser.ProgramContext) = { - val name = ":program" - val fullName = s"$relativeFilename:$name" - val programMethod = - methodNode( - programCtx, - name, - name, - fullName, - None, - relativeFilename, - Option(NodeTypes.TYPE_DECL), - Option(fullName) - ) - - classStack.push(fullName) - scope.pushNewScope(programMethod) - - val statementAsts = - if ( - programCtx.compoundStatement() != null && - programCtx.compoundStatement().statements() != null - ) { - astForStatements(programCtx.compoundStatement().statements(), false, false) ++ blockMethods - } else { - logger.error(s"File $filename has no compound statement. Needs to be examined") - List[Ast](Ast()) - } - - val methodRetNode = methodReturnNode(programCtx, Defines.Any) - - // For all the builtIn's encountered create assignment ast, minus user-defined methods with the same name - val lineColNum = 1 - val builtInMethodAst = builtInCallNames - .filterNot(methodNameToMethod.contains) - .map { builtInCallName => - val identifierNode = NewIdentifier() - .code(builtInCallName) - .name(builtInCallName) - .lineNumber(lineColNum) - .columnNumber(lineColNum) - .typeFullName(Defines.Any) - scope.addToScope(builtInCallName, identifierNode) - val typeRefNode = NewTypeRef() - .code(prefixAsBuiltin(builtInCallName)) - .typeFullName(prefixAsBuiltin(builtInCallName)) - .lineNumber(lineColNum) - .columnNumber(lineColNum) - astForAssignment(identifierNode, typeRefNode, Some(lineColNum), Some(lineColNum)) - } - .toList - - val methodRefAssignmentAsts = methodNameToMethod.values - .filterNot(_.astParentType == NodeTypes.TYPE_DECL) - .map { methodNode => - // Create a methodRefNode and assign it to the identifier version of the method, which will help in type propagation to resolve calls - val methodRefNode = NewMethodRef() - .code("def " + methodNode.name + "(...)") - .methodFullName(methodNode.fullName) - .typeFullName(methodNode.fullName) - .lineNumber(lineColNum) - .columnNumber(lineColNum) - - val methodNameIdentifier = NewIdentifier() - .code(methodNode.name) - .name(methodNode.name) - .typeFullName(Defines.Any) - .lineNumber(lineColNum) - .columnNumber(lineColNum) - scope.addToScope(methodNode.name, methodNameIdentifier) - val methodRefAssignmentAst = - astForAssignment(methodNameIdentifier, methodRefNode, methodNode.lineNumber, methodNode.columnNumber) - methodRefAssignmentAst - } - .toList - - val typeRefAssignmentAst = typeDeclNameToTypeDecl.values.map { typeDeclNode => - - val typeRefNode = NewTypeRef() - .code("class " + typeDeclNode.name + "(...)") - .typeFullName(typeDeclNode.fullName) - .lineNumber(typeDeclNode.lineNumber) - .columnNumber(typeDeclNode.columnNumber) - - val typeDeclNameIdentifier = NewIdentifier() - .code(typeDeclNode.name) - .name(typeDeclNode.name) - .typeFullName(Defines.Any) - .lineNumber(lineColNum) - .columnNumber(lineColNum) - scope.addToScope(typeDeclNode.name, typeDeclNameIdentifier) - val typeRefAssignmentAst = - astForAssignment(typeDeclNameIdentifier, typeRefNode, typeDeclNode.lineNumber, typeDeclNode.columnNumber) - typeRefAssignmentAst - } - - val methodDefInArgumentAsts = methodDefInArgument.toList - val locals = scope.createAndLinkLocalNodes(diffGraph).map(Ast.apply) - val programAst = - methodAst( - programMethod, - Seq.empty[Ast], - blockAst( - blockNode(programCtx), - locals ++ builtInMethodAst ++ methodRefAssignmentAsts ++ typeRefAssignmentAst ++ methodDefInArgumentAsts ++ statementAsts.toList - ), - methodRetNode, - newModifierNode(ModifierTypes.MODULE) :: Nil - ) - - scope.popScope() - - val fileNode = NewFile().name(relativeFilename).order(1) - val namespaceBlock = globalNamespaceBlock() - val ast = Ast(fileNode).withChild(Ast(namespaceBlock).withChild(programAst)) - - classStack.popAll() - - storeInDiffGraph(ast, diffGraph) - diffGraph - } - - def astForPrimaryContext(ctx: PrimaryContext): Seq[Ast] = ctx match { - case ctx: ClassDefinitionPrimaryContext if ctx.hasClassDefinition => astForClassDeclaration(ctx) - case ctx: ClassDefinitionPrimaryContext => astForClassExpression(ctx) - case ctx: ModuleDefinitionPrimaryContext => astForModuleDefinitionPrimaryContext(ctx) - case ctx: MethodDefinitionPrimaryContext => astForMethodDefinitionContext(ctx.methodDefinition()) - case ctx: ProcDefinitionPrimaryContext => astForProcDefinitionContext(ctx.procDefinition()) - case ctx: YieldWithOptionalArgumentPrimaryContext => - Seq(astForYieldCall(ctx, Option(ctx.yieldWithOptionalArgument().arguments()))) - case ctx: IfExpressionPrimaryContext => Seq(astForIfExpression(ctx.ifExpression())) - case ctx: UnlessExpressionPrimaryContext => Seq(astForUnlessExpression(ctx.unlessExpression())) - case ctx: CaseExpressionPrimaryContext => astForCaseExpressionPrimaryContext(ctx) - case ctx: WhileExpressionPrimaryContext => Seq(astForWhileExpression(ctx.whileExpression())) - case ctx: UntilExpressionPrimaryContext => Seq(astForUntilExpression(ctx.untilExpression())) - case ctx: ForExpressionPrimaryContext => Seq(astForForExpression(ctx.forExpression())) - case ctx: ReturnWithParenthesesPrimaryContext => - Seq(returnAst(returnNode(ctx, code(ctx)), astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses()))) - case ctx: JumpExpressionPrimaryContext => astForJumpExpressionPrimaryContext(ctx) - case ctx: BeginExpressionPrimaryContext => astForBeginExpressionPrimaryContext(ctx) - case ctx: GroupingExpressionPrimaryContext => astForCompoundStatement(ctx.compoundStatement(), false, false) - case ctx: VariableReferencePrimaryContext => Seq(astForVariableReference(ctx.variableReference())) - case ctx: SimpleScopedConstantReferencePrimaryContext => - astForSimpleScopedConstantReferencePrimaryContext(ctx) - case ctx: ChainedScopedConstantReferencePrimaryContext => - astForChainedScopedConstantReferencePrimaryContext(ctx) - case ctx: ArrayConstructorPrimaryContext => astForArrayLiteral(ctx.arrayConstructor()) - case ctx: HashConstructorPrimaryContext => astForHashConstructorPrimaryContext(ctx) - case ctx: LiteralPrimaryContext => astForLiteralPrimaryExpression(ctx) - case ctx: StringExpressionPrimaryContext => astForStringExpression(ctx.stringExpression) - case ctx: QuotedStringExpressionPrimaryContext => astForQuotedStringExpression(ctx.quotedStringExpression) - case ctx: RegexInterpolationPrimaryContext => - astForRegexInterpolationPrimaryContext(ctx.regexInterpolation) - case ctx: QuotedRegexInterpolationPrimaryContext => astForQuotedRegexInterpolation(ctx.quotedRegexInterpolation) - case ctx: IsDefinedPrimaryContext => Seq(astForIsDefinedPrimaryExpression(ctx)) - case ctx: SuperExpressionPrimaryContext => Seq(astForSuperExpression(ctx)) - case ctx: IndexingExpressionPrimaryContext => astForIndexingExpressionPrimaryContext(ctx) - case ctx: MethodOnlyIdentifierPrimaryContext => astForMethodOnlyIdentifier(ctx.methodOnlyIdentifier()) - case ctx: InvocationWithBlockOnlyPrimaryContext => astForInvocationWithBlockOnlyPrimaryContext(ctx) - case ctx: InvocationWithParenthesesPrimaryContext => astForInvocationWithParenthesesPrimaryContext(ctx) - case ctx: ChainedInvocationPrimaryContext => astForChainedInvocationPrimaryContext(ctx) - case ctx: ChainedInvocationWithoutArgumentsPrimaryContext => - astForChainedInvocationWithoutArgumentsPrimaryContext(ctx) - case _ => - logger.error(s"astForPrimaryContext() $relativeFilename, ${text(ctx)} All contexts mismatched.") - Seq(Ast()) - } - - def astForExpressionContext(ctx: ExpressionContext): Seq[Ast] = ctx match { - case ctx: PrimaryExpressionContext => astForPrimaryContext(ctx.primary()) - case ctx: UnaryExpressionContext => Seq(astForUnaryExpression(ctx)) - case ctx: PowerExpressionContext => Seq(astForPowerExpression(ctx)) - case ctx: UnaryMinusExpressionContext => Seq(astForUnaryMinusExpression(ctx)) - case ctx: MultiplicativeExpressionContext => Seq(astForMultiplicativeExpression(ctx)) - case ctx: AdditiveExpressionContext => Seq(astForAdditiveExpression(ctx)) - case ctx: BitwiseShiftExpressionContext => Seq(astForBitwiseShiftExpression(ctx)) - case ctx: BitwiseAndExpressionContext => Seq(astForBitwiseAndExpression(ctx)) - case ctx: BitwiseOrExpressionContext => Seq(astForBitwiseOrExpression(ctx)) - case ctx: RelationalExpressionContext => Seq(astForRelationalExpression(ctx)) - case ctx: EqualityExpressionContext => Seq(astForEqualityExpression(ctx)) - case ctx: OperatorAndExpressionContext => Seq(astForAndExpression(ctx)) - case ctx: OperatorOrExpressionContext => Seq(astForOrExpression(ctx)) - case ctx: RangeExpressionContext => astForRangeExpressionContext(ctx) - case ctx: ConditionalOperatorExpressionContext => Seq(astForTernaryConditionalOperator(ctx)) - case ctx: SingleAssignmentExpressionContext => astForSingleAssignmentExpressionContext(ctx) - case ctx: MultipleAssignmentExpressionContext => astForMultipleAssignmentExpressionContext(ctx) - case ctx: IsDefinedExpressionContext => Seq(astForIsDefinedExpression(ctx)) - case _ => - logger.error(s"astForExpressionContext() $relativeFilename, ${text(ctx)} All contexts mismatched.") - Seq(Ast()) - } - - protected def astForIndexingArgumentsContext(ctx: IndexingArgumentsContext): Seq[Ast] = ctx match { - case ctx: DeprecatedRubyParser.CommandOnlyIndexingArgumentsContext => - astForCommand(ctx.command()) - case ctx: DeprecatedRubyParser.ExpressionsOnlyIndexingArgumentsContext => - ctx - .expressions() - .expression() - .asScala - .flatMap(astForExpressionContext) - .toSeq - case ctx: DeprecatedRubyParser.ExpressionsAndSplattingIndexingArgumentsContext => - val expAsts = ctx - .expressions() - .expression() - .asScala - .flatMap(astForExpressionContext) - .toSeq - val splatAsts = astForExpressionOrCommand(ctx.splattingArgument().expressionOrCommand()) - val callNode = createOpCall(ctx.COMMA, Operators.arrayInitializer, code(ctx)) - Seq(callAst(callNode, expAsts ++ splatAsts)) - case ctx: AssociationsOnlyIndexingArgumentsContext => - astForAssociationsContext(ctx.associations()) - case ctx: DeprecatedRubyParser.SplattingOnlyIndexingArgumentsContext => - astForExpressionOrCommand(ctx.splattingArgument().expressionOrCommand()) - case _ => - logger.error(s"astForIndexingArgumentsContext() $relativeFilename, ${text(ctx)} All contexts mismatched.") - Seq(Ast()) - } - - private def astForBeginExpressionPrimaryContext(ctx: BeginExpressionPrimaryContext): Seq[Ast] = - astForBodyStatementContext(ctx.beginExpression().bodyStatement()) - - private def astForChainedInvocationPrimaryContext(ctx: ChainedInvocationPrimaryContext): Seq[Ast] = { - val hasBlockStmt = ctx.block() != null - val primaryAst = astForPrimaryContext(ctx.primary()) - val methodNameAst = - if (!hasBlockStmt && code(ctx.methodName()) == "new") astForCallToConstructor(ctx.methodName(), primaryAst) - else astForMethodNameContext(ctx.methodName()) - - val terminalNode = if (ctx.COLON2() != null) { - ctx.COLON2() - } else if (ctx.DOT() != null) { - ctx.DOT() - } else { - ctx.AMPDOT() - } - - val argsAst = if (ctx.argumentsWithParentheses() != null) { - astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses()) - } else { - Seq() - } - - if (hasBlockStmt) { - val blockName = methodNameAst.head.nodes.head - .asInstanceOf[NewCall] - .name - val blockMethodName = blockName + terminalNode.getSymbol.getLine - val blockMethodAsts = - astForBlockFunction( - ctxStmt = ctx.block().compoundStatement.statements(), - ctxParam = ctx.block().blockParameter, - blockMethodName, - line(ctx).head, - column(ctx).head, - lineEnd(ctx).head, - columnEnd(ctx).head - ) - val blockMethodNode = - blockMethodAsts.head.nodes.head - .asInstanceOf[NewMethod] - - blockMethods.addOne(blockMethodAsts.head) - - val callNode = NewCall() - .name(blockName) - .methodFullName(blockMethodNode.fullName) - .typeFullName(Defines.Any) - .code(blockMethodNode.code) - .lineNumber(blockMethodNode.lineNumber) - .columnNumber(blockMethodNode.columnNumber) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - - val methodRefNode = NewMethodRef() - .methodFullName(blockMethodNode.fullName) - .typeFullName(Defines.Any) - .code(blockMethodNode.code) - .lineNumber(blockMethodNode.lineNumber) - .columnNumber(blockMethodNode.columnNumber) - - Seq(callAst(callNode, argsAst ++ Seq(Ast(methodRefNode)), primaryAst.headOption)) - } else { - val callNode = methodNameAst.head.nodes - .filter(node => node.isInstanceOf[NewCall]) - .head - .asInstanceOf[NewCall] - - if (callNode.name == "call" && ctx.primary().isInstanceOf[ProcDefinitionPrimaryContext]) { - // this is a proc.call - val baseCallNode = primaryAst.head.nodes.head.asInstanceOf[NewCall] - Seq(callAst(baseCallNode, argsAst)) - } else { - callNode - .code(text(ctx)) - .lineNumber(terminalNode.lineNumber) - .columnNumber(terminalNode.columnNumber) - - primaryAst.headOption.flatMap(_.root) match { - case Some(methodNode: NewMethod) => - val methodRefNode = NewMethodRef() - .code("def " + methodNode.name + "(...)") - .methodFullName(methodNode.fullName) - .typeFullName(Defines.Any) - blockMethods.addOne(primaryAst.head) - Seq(callAst(callNode, Seq(Ast(methodRefNode)) ++ argsAst)) - case _ => - Seq(callAst(callNode, argsAst, primaryAst.headOption)) - } - } - } - } - - private def astForCallToConstructor(ctx: MethodNameContext, receiverAst: Seq[Ast]): Seq[Ast] = { - val receiverTypeName = receiverAst.flatMap(_.root).collectFirst { case x: NewIdentifier => x } match - case Some(receiverNode) if receiverNode.typeFullName != Defines.Any => - receiverNode.typeFullName - case Some(receiverNode) if typeDeclNameToTypeDecl.contains(receiverNode.name) => - typeDeclNameToTypeDecl(receiverNode.name).fullName - case _ => Defines.Any - - val name = XDefines.ConstructorMethodName - val (methodFullName, typeFullName) = - if (receiverTypeName != Defines.Any) - (Seq(receiverTypeName, XDefines.ConstructorMethodName).mkString(pathSep), receiverTypeName) - else (XDefines.DynamicCallUnknownFullName, Defines.Any) - - val constructorCall = - callNode(ctx, code(ctx), name, methodFullName, DispatchTypes.STATIC_DISPATCH, None, Option(typeFullName)) - Seq(Ast(constructorCall)) - } - - def astForChainedInvocationWithoutArgumentsPrimaryContext( - ctx: ChainedInvocationWithoutArgumentsPrimaryContext - ): Seq[Ast] = { - val methodNameAst = astForMethodNameContext(ctx.methodName()) - val baseAst = astForPrimaryContext(ctx.primary()) - - val blocksAst = if (ctx.block() != null) { - Seq(astForBlock(ctx.block())) - } else { - Seq() - } - val callNode = methodNameAst.head.nodes.filter(node => node.isInstanceOf[NewCall]).head.asInstanceOf[NewCall] - callNode - .code(text(ctx)) - .lineNumber(ctx.COLON2().getSymbol().getLine()) - .columnNumber(ctx.COLON2().getSymbol().getCharPositionInLine()) - Seq(callAst(callNode, baseAst ++ blocksAst)) - } - - private def astForChainedScopedConstantReferencePrimaryContext( - ctx: ChainedScopedConstantReferencePrimaryContext - ): Seq[Ast] = { - val primaryAst = astForPrimaryContext(ctx.primary()) - val localVar = ctx.CONSTANT_IDENTIFIER() - val varSymbol = localVar.getSymbol - val node = createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any)) - val constAst = Ast(node) - - val operatorName = getOperatorName(ctx.COLON2().getSymbol) - val callNode = createOpCall(ctx.COLON2, operatorName, code(ctx)) - Seq(callAst(callNode, primaryAst ++ Seq(constAst))) - } - - private def astForGroupedLeftHandSideContext(ctx: GroupedLeftHandSideContext): Seq[Ast] = { - astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide()) - } - - private def astForPackingLeftHandSideContext(ctx: PackingLeftHandSideContext): Seq[Ast] = { - astForSingleLeftHandSideContext(ctx.singleLeftHandSide()) - } - - def astForMultipleLeftHandSideContext(ctx: MultipleLeftHandSideContext): Seq[Ast] = ctx match { - case ctx: MultipleLeftHandSideAndpackingLeftHandSideMultipleLeftHandSideContext => - val multipleLHSAsts = ctx.multipleLeftHandSideItem.asScala.flatMap { item => - if (item.singleLeftHandSide != null) { - astForSingleLeftHandSideContext(item.singleLeftHandSide()) - } else { - astForGroupedLeftHandSideContext(item.groupedLeftHandSide()) - } - }.toList - - val paramAsts = - if (ctx.packingLeftHandSide() != null) { - val packingLHSAst = astForPackingLeftHandSideContext(ctx.packingLeftHandSide()) - multipleLHSAsts ++ packingLHSAst - } else { - multipleLHSAsts - } - - paramAsts - - case ctx: PackingLeftHandSideOnlyMultipleLeftHandSideContext => - astForPackingLeftHandSideContext(ctx.packingLeftHandSide()) - case ctx: GroupedLeftHandSideOnlyMultipleLeftHandSideContext => - astForGroupedLeftHandSideContext(ctx.groupedLeftHandSide()) - case _ => - logger.error(s"astForMultipleLeftHandSideContext() $relativeFilename, ${text(ctx)} All contexts mismatched.") - Seq(Ast()) - } - - // TODO: Clean-up and take into account other hash elements - private def astForHashConstructorPrimaryContext(ctx: HashConstructorPrimaryContext): Seq[Ast] = { - if (ctx.hashConstructor().hashConstructorElements() == null) return Seq(Ast()) - val hashCtorElemCtxs = ctx.hashConstructor().hashConstructorElements().hashConstructorElement().asScala - val associationCtxs = hashCtorElemCtxs.filter(_.association() != null).map(_.association()).toSeq - val expressionCtxs = hashCtorElemCtxs.filter(_.expression() != null).map(_.expression()).toSeq - expressionCtxs.flatMap(astForExpressionContext) ++ associationCtxs.flatMap(astForAssociationContext) - } - - def astForInvocationExpressionOrCommandContext(ctx: InvocationExpressionOrCommandContext): Seq[Ast] = { - if (ctx.EMARK() != null) { - val invocWOParenAsts = astForInvocationWithoutParenthesesContext(ctx.invocationWithoutParentheses()) - val operatorName = getOperatorName(ctx.EMARK().getSymbol) - val callNode = createOpCall(ctx.EMARK, operatorName, code(ctx)) - Seq(callAst(callNode, invocWOParenAsts)) - } else { - astForInvocationWithoutParenthesesContext(ctx.invocationWithoutParentheses()) - } - } - - private def astForInvocationWithoutParenthesesContext(ctx: InvocationWithoutParenthesesContext): Seq[Ast] = - ctx match { - case ctx: SingleCommandOnlyInvocationWithoutParenthesesContext => astForCommand(ctx.command()) - case ctx: ChainedCommandDoBlockInvocationWithoutParenthesesContext => - astForChainedCommandWithDoBlockContext(ctx.chainedCommandWithDoBlock()) - case ctx: ReturnArgsInvocationWithoutParenthesesContext => - val retNode = NewReturn() - .code(text(ctx)) - .lineNumber(ctx.RETURN().getSymbol.getLine) - .columnNumber(ctx.RETURN().getSymbol.getCharPositionInLine) - val argAst = Option(ctx.arguments).map(astForArguments).getOrElse(Seq()) - Seq(returnAst(retNode, argAst)) - case ctx: BreakArgsInvocationWithoutParenthesesContext => - astForBreakArgsInvocation(ctx) - case ctx: NextArgsInvocationWithoutParenthesesContext => - astForNextArgsInvocation(ctx) - case _ => - logger.error( - s"astForInvocationWithoutParenthesesContext() $relativeFilename, ${text(ctx)} All contexts mismatched." - ) - Seq(Ast()) - } - - private def astForInvocationWithBlockOnlyPrimaryContext(ctx: InvocationWithBlockOnlyPrimaryContext): Seq[Ast] = { - val methodIdAst = astForMethodIdentifierContext(ctx.methodIdentifier(), code(ctx)) - val blockName = methodIdAst.head.nodes.head - .asInstanceOf[NewCall] - .name - - val isYieldMethod = if (blockName.endsWith(YIELD_SUFFIX)) { - val lookupMethodName = blockName.take(blockName.length - YIELD_SUFFIX.length) - methodNamesWithYield.contains(lookupMethodName) - } else { - false - } - - if (isYieldMethod) { - /* - * This is a yield block. Create a fake method out of it. The yield call will be a call to the yield block - */ - astForBlockFunction( - ctx.block().compoundStatement.statements(), - ctx.block().blockParameter, - blockName, - line(ctx).head, - lineEnd(ctx).head, - column(ctx).head, - columnEnd(ctx).head - ) - } else { - val blockAst = Seq(astForBlock(ctx.block())) - // this is expected to be a call node - val callNode = methodIdAst.head.nodes.head.asInstanceOf[NewCall] - Seq(callAst(callNode, blockAst)) - } - } - - private def astForInvocationWithParenthesesPrimaryContext(ctx: InvocationWithParenthesesPrimaryContext): Seq[Ast] = { - val methodIdAst = astForMethodIdentifierContext(ctx.methodIdentifier(), code(ctx)) - val parenAst = astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses()) - val callNode = methodIdAst.head.nodes.filter(_.isInstanceOf[NewCall]).head.asInstanceOf[NewCall] - callNode.name(resolveAlias(callNode.name)) - - if (ctx.block() != null) { - val isYieldMethod = if (callNode.name.endsWith(YIELD_SUFFIX)) { - val lookupMethodName = callNode.name.take(callNode.name.length - YIELD_SUFFIX.length) - methodNamesWithYield.contains(lookupMethodName) - } else { - false - } - if (isYieldMethod) { - val methAst = astForBlock(ctx.block(), Some(callNode.name)) - blockMethods.addOne(methAst) - Seq(callAst(callNode, parenAst)) - } else { - val blockAst = Seq(astForBlock(ctx.block())) - Seq(callAst(callNode, parenAst ++ blockAst)) - } - } else - Seq(callAst(callNode, parenAst)) - } - - def astForCallNode(ctx: ParserRuleContext, code: String, isYieldBlock: Boolean = false): Ast = { - val name = if (isYieldBlock) { - s"${resolveAlias(text(ctx))}$YIELD_SUFFIX" - } else { - val calleeName = resolveAlias(text(ctx)) - // Add the call name to the global builtIn callNames set - if (isBuiltin(calleeName)) builtInCallNames.add(calleeName) - calleeName - } - - callAst(callNode(ctx, code, name, DynamicCallUnknownFullName, DispatchTypes.STATIC_DISPATCH)) - } - - private def astForMethodOnlyIdentifier(ctx: MethodOnlyIdentifierContext): Seq[Ast] = { - if (ctx.LOCAL_VARIABLE_IDENTIFIER() != null) { - Seq(astForCallNode(ctx, code(ctx))) - } else if (ctx.CONSTANT_IDENTIFIER() != null) { - Seq(astForCallNode(ctx, code(ctx))) - } else if (ctx.keyword() != null) { - Seq(astForCallNode(ctx, code(ctx.keyword()))) - } else { - Seq(Ast()) - } - } - - def astForMethodIdentifierContext(ctx: MethodIdentifierContext, code: String): Seq[Ast] = { - // the local/const identifiers are definitely method names - if (ctx.methodOnlyIdentifier() != null) { - astForMethodOnlyIdentifier(ctx.methodOnlyIdentifier()) - } else if (ctx.LOCAL_VARIABLE_IDENTIFIER() != null) { - val localVar = ctx.LOCAL_VARIABLE_IDENTIFIER() - val varSymbol = localVar.getSymbol - Seq(astForCallNode(ctx, code, methodNamesWithYield.contains(varSymbol.getText))) - } else if (ctx.CONSTANT_IDENTIFIER() != null) { - Seq(astForCallNode(ctx, code)) - } else { - Seq.empty - } - } - - def astForRescueClauseContext(ctx: RescueClauseContext): Ast = { - val asts = ListBuffer.empty[Ast] - - if (ctx.exceptionClass() != null) { - val exceptionClass = ctx.exceptionClass() - - if (exceptionClass.expression() != null) { - asts.addAll(astForExpressionContext(exceptionClass.expression())) - } else { - asts.addAll(astForMultipleRightHandSideContext(exceptionClass.multipleRightHandSide())) - } - } - - if (ctx.exceptionVariableAssignment() != null) { - asts.addAll(astForSingleLeftHandSideContext(ctx.exceptionVariableAssignment().singleLeftHandSide())) - } - - asts.addAll(astForCompoundStatement(ctx.thenClause().compoundStatement(), false)) - blockAst(blockNode(ctx), asts.toList) - } - - private def astForSimpleScopedConstantReferencePrimaryContext( - ctx: SimpleScopedConstantReferencePrimaryContext - ): Seq[Ast] = { - val localVar = ctx.CONSTANT_IDENTIFIER - val varSymbol = localVar.getSymbol - val node = createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any)) - - val operatorName = getOperatorName(ctx.COLON2.getSymbol) - val callNode = createOpCall(ctx.COLON2, operatorName, code(ctx)) - - Seq(callAst(callNode, Seq(Ast(node)))) - } - - private def astForCommandWithDoBlockContext(ctx: CommandWithDoBlockContext): Seq[Ast] = ctx match { - case ctx: ArgsAndDoBlockCommandWithDoBlockContext => - val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - val doBlockAst = Seq(astForDoBlock(ctx.doBlock())) - argsAsts ++ doBlockAst - case ctx: DeprecatedRubyParser.ArgsAndDoBlockAndMethodIdCommandWithDoBlockContext => - val methodIdAsts = astForMethodIdentifierContext(ctx.methodIdentifier(), code(ctx)) - methodIdAsts.headOption.flatMap(_.root) match - case Some(methodIdRoot: NewCall) if methodIdRoot.name == "define_method" => - ctx.argumentsWithoutParentheses.arguments.argument.asScala.headOption - .map { methodArg => - // TODO: methodArg will name the method, but this could be an identifier or even a string concatenation - // which is not assumed below - val methodName = stripQuotes(methodArg.getText) - Seq(astForDoBlock(ctx.doBlock(), Option(methodName))) - } - .getOrElse(Seq.empty) - case _ => - val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - val doBlockAsts = Seq(astForDoBlock(ctx.doBlock())) - methodIdAsts ++ argsAsts ++ doBlockAsts - case ctx: DeprecatedRubyParser.PrimaryMethodArgsDoBlockCommandWithDoBlockContext => - val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - val doBlockAsts = Seq(astForDoBlock(ctx.doBlock())) - val methodNameAsts = astForMethodNameContext(ctx.methodName()) - val primaryAsts = astForPrimaryContext(ctx.primary()) - primaryAsts ++ methodNameAsts ++ argsAsts ++ doBlockAsts - case _ => - logger.error(s"astForCommandWithDoBlockContext() $relativeFilename, ${text(ctx)} All contexts mismatched.") - Seq(Ast()) - } - - private def astForChainedCommandWithDoBlockContext(ctx: ChainedCommandWithDoBlockContext): Seq[Ast] = { - val cmdAsts = astForCommandWithDoBlockContext(ctx.commandWithDoBlock) - val mNameAsts = ctx.methodName.asScala.flatMap(astForMethodNameContext).toSeq - val apAsts = ctx - .argumentsWithParentheses() - .asScala - .flatMap(astForArgumentsWithParenthesesContext) - .toSeq - cmdAsts ++ mNameAsts ++ apAsts - } - - protected def astForArgumentsWithParenthesesContext(ctx: ArgumentsWithParenthesesContext): Seq[Ast] = ctx match { - case _: BlankArgsArgumentsWithParenthesesContext => Seq.empty - case ctx: ArgsOnlyArgumentsWithParenthesesContext => astForArguments(ctx.arguments) - case ctx: ExpressionsAndChainedCommandWithDoBlockArgumentsWithParenthesesContext => - val expAsts = ctx.expressions.expression.asScala - .flatMap(astForExpressionContext) - .toSeq - val ccDoBlock = astForChainedCommandWithDoBlockContext(ctx.chainedCommandWithDoBlock) - expAsts ++ ccDoBlock - case ctx: ChainedCommandWithDoBlockOnlyArgumentsWithParenthesesContext => - astForChainedCommandWithDoBlockContext(ctx.chainedCommandWithDoBlock) - case _ => - logger.error(s"astForArgumentsWithParenthesesContext() $relativeFilename, ${text(ctx)} All contexts mismatched.") - Seq(Ast()) - } - - private def astForBlockParametersContext(ctx: BlockParametersContext): Seq[Ast] = - if (ctx.singleLeftHandSide != null) { - astForSingleLeftHandSideContext(ctx.singleLeftHandSide) - } else if (ctx.multipleLeftHandSide != null) { - astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide) - } else { - Seq.empty - } - - protected def astForBlockParameterContext(ctx: BlockParameterContext): Seq[Ast] = - if (ctx.blockParameters != null) { - astForBlockParametersContext(ctx.blockParameters) - } else { - Seq.empty - } - - def astForAssociationContext(ctx: AssociationContext): Seq[Ast] = { - val terminalNode = Option(ctx.COLON).getOrElse(ctx.EQGT) - val operatorText = getOperatorName(terminalNode.getSymbol) - val expressions = ctx.expression.asScala - - val callArgs = - Option(ctx.keyword) match { - case Some(ctxKeyword) => - val expr1Ast = astForCallNode(ctx, code(ctxKeyword)) - val expr2Asts = astForExpressionContext(expressions.head) - Seq(expr1Ast) ++ expr2Asts - case None => - val expr1Asts = astForExpressionContext(expressions.head) - val expr2Asts = expressions.lift(1).flatMap(astForExpressionContext) - expr1Asts ++ expr2Asts - } - - val callNode = createOpCall(terminalNode, operatorText, code(ctx)) - Seq(callAst(callNode, callArgs)) - } - - private def astForAssociationsContext(ctx: AssociationsContext): Seq[Ast] = { - ctx.association.asScala - .flatMap(astForAssociationContext) - .toSeq - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstCreatorHelper.scala deleted file mode 100644 index d8ad2c3c609f..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstCreatorHelper.scala +++ /dev/null @@ -1,361 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.* -import io.joern.rubysrc2cpg.deprecated.passes.Defines as RubyDefines -import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, nodes} -import org.antlr.v4.runtime.misc.Interval -import org.antlr.v4.runtime.tree.TerminalNode -import org.antlr.v4.runtime.{ParserRuleContext, Token} - -import scala.collection.mutable -import scala.util.Try - -trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - import io.joern.rubysrc2cpg.deprecated.astcreation.GlobalTypes.* - - protected def line(ctx: ParserRuleContext): Option[Int] = - Try(ctx.getStart.getLine).toOption - - protected def column(ctx: ParserRuleContext): Option[Int] = - Try(ctx.getStart.getCharPositionInLine).toOption - - protected def lineEnd(ctx: ParserRuleContext): Option[Int] = - Try(ctx.getStop.getLine).toOption - - protected def columnEnd(ctx: ParserRuleContext): Option[Int] = - Try(ctx.getStop.getCharPositionInLine).toOption - - override def code(node: ParserRuleContext): String = shortenCode(text(node)) - - protected def text(ctx: ParserRuleContext): String = Try { - val a = ctx.getStart.getStartIndex - val b = ctx.getStop.getStopIndex - val intv = new Interval(a, b) - val input = ctx.getStart.getInputStream - input.getText(intv) - }.getOrElse("") - - protected def isBuiltin(x: String): Boolean = builtinFunctions.contains(x) - - protected def prefixAsBuiltin(x: String): String = s"$builtinPrefix$pathSep$x" - - protected def methodsWithName(name: String): List[String] = { - packageContext.packageTable.getMethodFullNameUsingName(methodName = name) - } - - private def methodTableToCallNode( - methodFullName: String, - name: String, - code: String, - typeFullName: String, - dynamicTypeHints: Seq[String] = Seq(), - ctx: Option[ParserRuleContext] = None - ): NewCall = { - callNode(ctx.orNull, code, name, methodFullName, DispatchTypes.DYNAMIC_DISPATCH, None, Option(typeFullName)) - .dynamicTypeHintFullName(dynamicTypeHints) - } - - /** Checks that the name is not `this` and that the method has been referred to more than just an initial assignment - * to METHOD_REF. - * - * @param name - * the identifier name. - * @return - * true if this appears to be more like a method call than an identifier. - */ - private def isMethodCall(name: String): Boolean = { - name != "this" && scope.numVariableReferences(name) == 0 - } - - protected def createIdentifierWithScope( - ctx: ParserRuleContext, - name: String, - code: String, - typeFullName: String, - dynamicTypeHints: Seq[String] = Seq(), - definitelyIdentifier: Boolean = false - ): NewNode = { - methodsWithName(name) match - case method :: _ if !definitelyIdentifier && isMethodCall(name) => - methodTableToCallNode(method, name, code, typeFullName, dynamicTypeHints, Option(ctx)) - case _ => - val newNode = identifierNode(ctx, name, code, typeFullName, dynamicTypeHints) - scope.addToScope(name, newNode) - newNode - } - - protected def createIdentifierWithScope( - name: String, - code: String, - typeFullName: String, - dynamicTypeHints: Seq[String], - lineNumber: Option[Int], - columnNumber: Option[Int], - definitelyIdentifier: Boolean - ): NewNode = { - methodsWithName(name) match - case method :: _ if !definitelyIdentifier && isMethodCall(name) => - methodTableToCallNode(method, name, code, typeFullName, dynamicTypeHints, None) - case _ => - val newNode = NewIdentifier() - .name(name) - .code(code) - .typeFullName(typeFullName) - .dynamicTypeHintFullName(dynamicTypeHints) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - scope.addToScope(name, newNode) - newNode - } - - protected def createOpCall( - node: TerminalNode, - operation: String, - code: String, - typeFullName: String = RubyDefines.Any - ): NewCall = { - NewCall() - .name(operation) - .methodFullName(operation) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .lineNumber(node.lineNumber) - .columnNumber(node.columnNumber) - .typeFullName(typeFullName) - .code(code) - } - - protected def createLiteralNode( - code: String, - typeFullName: String, - dynamicTypeHints: Seq[String] = Seq.empty, - lineNumber: Option[Int] = None, - columnNumber: Option[Int] = None - ): NewLiteral = { - val newLiteral = NewLiteral() - .code(code) - .typeFullName(typeFullName) - .dynamicTypeHintFullName(dynamicTypeHints) - lineNumber.foreach(newLiteral.lineNumber(_)) - columnNumber.foreach(newLiteral.columnNumber(_)) - newLiteral - } - - protected def astForAssignment( - lhs: NewNode, - rhs: NewNode, - lineNumber: Option[Int] = None, - colNumber: Option[Int] = None - ): Ast = { - val code = Seq(lhs, rhs).collect { case x: AstNodeNew => x.code }.mkString(" = ") - val assignment = NewCall() - .name(Operators.assignment) - .methodFullName(Operators.assignment) - .code(code) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .lineNumber(lineNumber) - .columnNumber(colNumber) - - callAst(assignment, Seq(Ast(lhs), Ast(rhs))) - } - - protected def createThisIdentifier( - ctx: ParserRuleContext, - typeFullName: String = RubyDefines.Any, - dynamicTypeHints: List[String] = List.empty - ): NewIdentifier = - createIdentifierWithScope(ctx, "this", "this", typeFullName, dynamicTypeHints, true).asInstanceOf[NewIdentifier] - - protected def newFieldIdentifier(ctx: ParserRuleContext): NewFieldIdentifier = { - val c = code(ctx) - val name = c.replaceAll("@", "") - NewFieldIdentifier() - .code(c) - .canonicalName(name) - .lineNumber(ctx.start.getLine) - .columnNumber(ctx.start.getCharPositionInLine) - } - - protected def astForFieldAccess(ctx: ParserRuleContext, baseNode: NewNode): Ast = { - val fieldAccess = - callNode(ctx, code(ctx), Operators.fieldAccess, Operators.fieldAccess, DispatchTypes.STATIC_DISPATCH) - val fieldIdentifier = newFieldIdentifier(ctx) - val astChildren = Seq(baseNode, fieldIdentifier) - callAst(fieldAccess, astChildren.map(Ast.apply)) - } - - protected def createMethodParameterIn( - name: String, - lineNumber: Option[Int] = None, - colNumber: Option[Int] = None, - typeFullName: String = RubyDefines.Any, - order: Int = -1, - index: Int = -1 - ): NewMethodParameterIn = { - NewMethodParameterIn() - .name(name) - .code(name) - .lineNumber(lineNumber) - .typeFullName(typeFullName) - .columnNumber(colNumber) - .order(order) - .index(index) - } - - protected def getUnusedVariableNames( - usedVariableNames: mutable.HashMap[String, Int], - variableName: String - ): String = { - val counter = usedVariableNames.get(variableName).map(_ + 1).getOrElse(0) - val currentVariableName = s"${variableName}_$counter" - usedVariableNames.put(variableName, counter) - currentVariableName - } - - protected def astForControlStructure( - parserTypeName: String, - node: TerminalNode, - controlStructureType: String, - code: String - ): Ast = - Ast( - NewControlStructure() - .parserTypeName(parserTypeName) - .controlStructureType(controlStructureType) - .code(code) - .lineNumber(node.lineNumber) - .columnNumber(node.columnNumber) - ) - - protected def returnNode(node: TerminalNode, code: String): NewReturn = - NewReturn() - .lineNumber(node.lineNumber) - .columnNumber(node.columnNumber) - .code(code) - - protected def getOperatorName(token: Token): String = token.getType match { - case ASSIGNMENT_OPERATOR => Operators.assignment - case DOT2 => Operators.range - case DOT3 => Operators.range - case EMARK => Operators.not - case EQ => Operators.assignment - case COLON2 => RubyOperators.scopeResolution - case DOT => Operators.fieldAccess - case EQGT => RubyOperators.keyValueAssociation - case COLON => RubyOperators.activeRecordAssociation - case _ => RubyOperators.none - } - - implicit class TerminalNodeExt(n: TerminalNode) { - - def lineNumber: Int = n.getSymbol.getLine - - def columnNumber: Int = n.getSymbol.getCharPositionInLine - - } - -} - -object RubyOperators { - val none = ".none" - val patternMatch = ".patternMatch" - val notPatternMatch = ".notPatternMatch" - val scopeResolution = ".scopeResolution" - val defined = ".defined" - val keyValueAssociation = ".keyValueAssociation" - val activeRecordAssociation = ".activeRecordAssociation" - val undef = ".undef" - val superKeyword = ".super" - val stringConcatenation = ".stringConcatenation" - val formattedString = ".formatString" - val formattedValue = ".formatValue" -} - -object GlobalTypes { - val builtinPrefix = "__builtin" - /* Sources: - * https://ruby-doc.org/docs/ruby-doc-bundle/Manual/man-1.4/function.html - * https://ruby-doc.org/3.2.2/Kernel.html - * - * We comment-out methods that require an explicit "receiver" (target of member access.) - */ - val builtinFunctions = Set( - "Array", - "Complex", - "Float", - "Hash", - "Integer", - "Rational", - "String", - "__callee__", - "__dir__", - "__method__", - "abort", - "at_exit", - "autoload", - "autoload?", - "binding", - "block_given?", - "callcc", - "caller", - "caller_locations", - "catch", - "chomp", - "chomp!", - "chop", - "chop!", - // "class", - // "clone", - "eval", - "exec", - "exit", - "exit!", - "fail", - "fork", - "format", - // "frozen?", - "gets", - "global_variables", - "gsub", - "gsub!", - "iterator?", - "lambda", - "load", - "local_variables", - "loop", - "open", - "p", - "print", - "printf", - "proc", - "putc", - "puts", - "raise", - "rand", - "readline", - "readlines", - "require", - "require_relative", - "select", - "set_trace_func", - "sleep", - "spawn", - "sprintf", - "srand", - "sub", - "sub!", - "syscall", - "system", - "tap", - "test", - // "then", - "throw", - "trace_var", - // "trap", - "untrace_var", - "warn" - // "yield_self", - ) -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForControlStructuresCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForControlStructuresCreator.scala deleted file mode 100644 index 89424fad8c9b..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForControlStructuresCreator.scala +++ /dev/null @@ -1,135 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.* -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.ControlStructureTypes -import io.shiftleft.codepropertygraph.generated.nodes.NewControlStructure - -import scala.jdk.CollectionConverters.* -trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - private def astForWhenArgumentContext(ctx: WhenArgumentContext): Seq[Ast] = { - val expAsts = - ctx.expressions.expression.asScala - .flatMap(astForExpressionContext) - .toList - - if (ctx.splattingArgument != null) { - expAsts ++ astForExpressionOrCommand(ctx.splattingArgument().expressionOrCommand()) - } else { - expAsts - } - } - - protected def astForCaseExpressionPrimaryContext(ctx: CaseExpressionPrimaryContext): Seq[Ast] = { - val codeString = s"case ${Option(ctx.caseExpression().expressionOrCommand).map(code).getOrElse("")}".stripTrailing() - val switchNode = controlStructureNode(ctx, ControlStructureTypes.SWITCH, codeString) - val conditionAst = Option(ctx.caseExpression().expressionOrCommand()).toList - .flatMap(astForExpressionOrCommand) - .headOption - - val whenThenAstsList = ctx - .caseExpression() - .whenClause() - .asScala - .flatMap(wh => { - val whenNode = - jumpTargetNode(wh, "case", s"case ${code(wh)}", Option(wh.getClass.getSimpleName)) - - val whenACondAsts = astForWhenArgumentContext(wh.whenArgument()) - val thenAsts = astForCompoundStatement( - wh.thenClause().compoundStatement(), - isMethodBody = true, - canConsiderAsLeaf = false - ) ++ Seq(Ast(NewControlStructure().controlStructureType(ControlStructureTypes.BREAK))) - Seq(Ast(whenNode)) ++ whenACondAsts ++ thenAsts - }) - .toList - - val stmtAsts = whenThenAstsList ++ (Option(ctx.caseExpression().elseClause()) match - case Some(elseClause) => - Ast( - // name = "default" for behaviour determined by CfgCreator.cfgForJumpTarget - jumpTargetNode(elseClause, "default", "else", Option(elseClause.getClass.getSimpleName)) - ) +: astForCompoundStatement(elseClause.compoundStatement(), isMethodBody = true, canConsiderAsLeaf = false) - case None => Seq.empty[Ast] - ) - val block = blockNode(ctx.caseExpression()) - Seq(controlStructureAst(switchNode, conditionAst, Seq(Ast(block).withChildren(stmtAsts)))) - } - - protected def astForNextArgsInvocation(ctx: NextArgsInvocationWithoutParenthesesContext): Seq[Ast] = { - /* - * While this is a `CONTINUE` for now, if we detect that this is the LHS of an `IF` then this becomes a `RETURN` - */ - Seq( - astForControlStructure( - ctx.getClass.getSimpleName, - ctx.NEXT(), - ControlStructureTypes.CONTINUE, - Defines.ModifierNext - ).withChildren(astForArguments(ctx.arguments())) - ) - } - - protected def astForBreakArgsInvocation(ctx: BreakArgsInvocationWithoutParenthesesContext): Seq[Ast] = { - Option(ctx.arguments()) match { - case Some(args) => - /* - * This is break with args inside a block. The argument passed to break will be returned by the bloc - * Model this as a return since this is effectively a return - */ - val retNode = returnNode(ctx.BREAK(), code(ctx)) - val argAst = astForArguments(args) - Seq(returnAst(retNode, argAst)) - case None => - Seq( - astForControlStructure(ctx.getClass.getSimpleName, ctx.BREAK(), ControlStructureTypes.BREAK, code(ctx)) - .withChildren(astForArguments(ctx.arguments)) - ) - } - } - - protected def astForJumpExpressionPrimaryContext(ctx: JumpExpressionPrimaryContext): Seq[Ast] = { - val parserTypeName = ctx.getClass.getSimpleName - val controlStructureAst = ctx.jumpExpression() match - case expr if expr.BREAK() != null => - astForControlStructure(parserTypeName, expr.BREAK(), ControlStructureTypes.BREAK, code(ctx)) - case expr if expr.NEXT() != null => - astForControlStructure(parserTypeName, expr.NEXT(), ControlStructureTypes.CONTINUE, Defines.ModifierNext) - case expr if expr.REDO() != null => - astForControlStructure(parserTypeName, expr.REDO(), ControlStructureTypes.CONTINUE, Defines.ModifierRedo) - case expr if expr.RETRY() != null => - astForControlStructure(parserTypeName, expr.RETRY(), ControlStructureTypes.CONTINUE, Defines.ModifierRetry) - case _ => - Ast() - Seq(controlStructureAst) - } - - protected def astForRescueClause(ctx: BodyStatementContext): Ast = { - val compoundStatementAsts = astForCompoundStatement(ctx.compoundStatement) - val elseClauseAsts = Option(ctx.elseClause) match - case Some(ctx) => astForCompoundStatement(ctx.compoundStatement) - case None => Seq.empty - - /* - * TODO Conversion of last statement to return AST is needed here - * This can be done after the data flow engine issue with return from a try block is fixed - */ - val tryBodyAsts = compoundStatementAsts ++ elseClauseAsts - val tryBodyAst = blockAst(blockNode(ctx), tryBodyAsts.toList) - - val finallyAst = Option(ctx.ensureClause) match - case Some(ctx) => astForCompoundStatement(ctx.compoundStatement).headOption - case None => None - - val catchAsts = ctx.rescueClause.asScala - .map(astForRescueClauseContext) - .toSeq - - val tryNode = controlStructureNode(ctx, ControlStructureTypes.TRY, "try") - tryCatchAstWithOrder(tryNode, tryBodyAst, catchAsts, finallyAst) - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForDeclarationsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForDeclarationsCreator.scala deleted file mode 100644 index f49b17a5495f..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForDeclarationsCreator.scala +++ /dev/null @@ -1,34 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.* -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.x2cpg.Ast -import io.shiftleft.codepropertygraph.generated.nodes.{NewJumpTarget, NewLiteral} -import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators} -import org.antlr.v4.runtime.ParserRuleContext -import org.slf4j.LoggerFactory - -import scala.jdk.CollectionConverters.CollectionHasAsScala - -trait AstForDeclarationsCreator { this: AstCreator => - - private val logger = LoggerFactory.getLogger(this.getClass) - - protected def astForArguments(ctx: ArgumentsContext): Seq[Ast] = { - ctx.argument().asScala.flatMap(astForArgument).toSeq - } - - protected def astForArgument(ctx: ArgumentContext): Seq[Ast] = { - ctx match { - case ctx: BlockArgumentArgumentContext => astForExpressionContext(ctx.blockArgument.expression) - case ctx: SplattingArgumentArgumentContext => astForExpressionOrCommand(ctx.splattingArgument.expressionOrCommand) - case ctx: ExpressionArgumentContext => astForExpressionContext(ctx.expression) - case ctx: AssociationArgumentContext => astForAssociationContext(ctx.association) - case ctx: CommandArgumentContext => astForCommand(ctx.command) - case ctx: HereDocArgumentContext => astForHereDocArgument(ctx) - case _ => - logger.error(s"astForArgument() $relativeFilename, ${ctx.getText} All contexts mismatched.") - Seq() - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForExpressionsCreator.scala deleted file mode 100644 index 0897863c28a9..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForExpressionsCreator.scala +++ /dev/null @@ -1,512 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.* -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.rubysrc2cpg.deprecated.passes.Defines.* -import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier} -import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators} -import org.antlr.v4.runtime.ParserRuleContext -import org.slf4j.LoggerFactory - -import scala.collection.immutable.Set -import scala.jdk.CollectionConverters.CollectionHasAsScala - -trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - private val logger = LoggerFactory.getLogger(this.getClass) - protected var lastModifier: Option[String] = None - - protected def astForPowerExpression(ctx: PowerExpressionContext): Ast = - astForBinaryOperatorExpression(ctx, Operators.exponentiation, ctx.expression().asScala) - - protected def astForOrExpression(ctx: OperatorOrExpressionContext): Ast = - astForBinaryOperatorExpression(ctx, Operators.or, ctx.expression().asScala) - - protected def astForAndExpression(ctx: OperatorAndExpressionContext): Ast = - astForBinaryOperatorExpression(ctx, Operators.and, ctx.expression().asScala) - - protected def astForUnaryExpression(ctx: UnaryExpressionContext): Ast = ctx.op.getType match { - case TILDE => astForBinaryOperatorExpression(ctx, Operators.not, Seq(ctx.expression())) - case PLUS => astForBinaryOperatorExpression(ctx, Operators.plus, Seq(ctx.expression())) - case EMARK => astForBinaryOperatorExpression(ctx, Operators.not, Seq(ctx.expression())) - } - - protected def astForUnaryMinusExpression(ctx: UnaryMinusExpressionContext): Ast = - astForBinaryOperatorExpression(ctx, Operators.minus, Seq(ctx.expression())) - - protected def astForAdditiveExpression(ctx: AdditiveExpressionContext): Ast = ctx.op.getType match { - case PLUS => astForBinaryOperatorExpression(ctx, Operators.addition, ctx.expression().asScala) - case MINUS => astForBinaryOperatorExpression(ctx, Operators.subtraction, ctx.expression().asScala) - } - - protected def astForMultiplicativeExpression(ctx: MultiplicativeExpressionContext): Ast = ctx.op.getType match { - case STAR => astForMultiplicativeStarExpression(ctx) - case SLASH => astForMultiplicativeSlashExpression(ctx) - case PERCENT => astForMultiplicativePercentExpression(ctx) - } - - protected def astForMultiplicativeStarExpression(ctx: MultiplicativeExpressionContext): Ast = - astForBinaryOperatorExpression(ctx, Operators.multiplication, ctx.expression().asScala) - - protected def astForMultiplicativeSlashExpression(ctx: MultiplicativeExpressionContext): Ast = - astForBinaryOperatorExpression(ctx, Operators.division, ctx.expression().asScala) - - protected def astForMultiplicativePercentExpression(ctx: MultiplicativeExpressionContext): Ast = - astForBinaryOperatorExpression(ctx, Operators.modulo, ctx.expression().asScala) - - protected def astForEqualityExpression(ctx: EqualityExpressionContext): Ast = ctx.op.getType match { - case LTEQGT => astForBinaryOperatorExpression(ctx, Operators.compare, ctx.expression().asScala) - case EQ2 => astForBinaryOperatorExpression(ctx, Operators.equals, ctx.expression().asScala) - case EQ3 => astForBinaryOperatorExpression(ctx, Operators.is, ctx.expression().asScala) - case EMARKEQ => astForBinaryOperatorExpression(ctx, Operators.notEquals, ctx.expression().asScala) - case EQTILDE => astForBinaryOperatorExpression(ctx, RubyOperators.patternMatch, ctx.expression().asScala) - case EMARKTILDE => astForBinaryOperatorExpression(ctx, RubyOperators.notPatternMatch, ctx.expression().asScala) - } - - protected def astForRelationalExpression(ctx: RelationalExpressionContext): Ast = ctx.op.getType match { - case GT => astForBinaryOperatorExpression(ctx, Operators.greaterThan, ctx.expression().asScala) - case GTEQ => astForBinaryOperatorExpression(ctx, Operators.greaterEqualsThan, ctx.expression().asScala) - case LT => astForBinaryOperatorExpression(ctx, Operators.lessThan, ctx.expression().asScala) - case LTEQ => astForBinaryOperatorExpression(ctx, Operators.lessEqualsThan, ctx.expression().asScala) - } - - protected def astForBitwiseOrExpression(ctx: BitwiseOrExpressionContext): Ast = ctx.op.getType match { - case BAR => astForBinaryOperatorExpression(ctx, Operators.logicalOr, ctx.expression().asScala) - case CARET => astForBinaryOperatorExpression(ctx, Operators.logicalOr, ctx.expression().asScala) - } - - protected def astForBitwiseAndExpression(ctx: BitwiseAndExpressionContext): Ast = - astForBinaryOperatorExpression(ctx, Operators.logicalAnd, ctx.expression().asScala) - - protected def astForBitwiseShiftExpression(ctx: BitwiseShiftExpressionContext): Ast = ctx.op.getType match { - case LT2 => astForBinaryOperatorExpression(ctx, Operators.shiftLeft, ctx.expression().asScala) - case GT2 => astForBinaryOperatorExpression(ctx, Operators.logicalShiftRight, ctx.expression().asScala) - } - - private def astForBinaryOperatorExpression( - ctx: ParserRuleContext, - name: String, - arguments: Iterable[ExpressionContext] - ): Ast = { - val argsAst = arguments.flatMap(astForExpressionContext) - val call = callNode(ctx, code(ctx), name, name, DispatchTypes.STATIC_DISPATCH) - callAst(call, argsAst.toList) - } - - protected def astForIsDefinedExpression(ctx: IsDefinedExpressionContext): Ast = - astForBinaryOperatorExpression(ctx, RubyOperators.defined, Seq(ctx.expression())) - - // TODO: Maybe merge (in DeprecatedRubyParser.g4) isDefinedExpression with isDefinedPrimaryExpression? - protected def astForIsDefinedPrimaryExpression(ctx: IsDefinedPrimaryContext): Ast = { - val argsAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val call = callNode(ctx, code(ctx), RubyOperators.defined, RubyOperators.defined, DispatchTypes.STATIC_DISPATCH) - callAst(call, argsAst.toList) - } - - protected def astForLiteralPrimaryExpression(ctx: LiteralPrimaryContext): Seq[Ast] = ctx.literal() match { - case ctx: NumericLiteralLiteralContext => Seq(astForNumericLiteral(ctx.numericLiteral())) - case ctx: SymbolLiteralContext => astForSymbol(ctx.symbol()) - case ctx: RegularExpressionLiteralContext => Seq(astForRegularExpressionLiteral(ctx)) - case ctx: HereDocLiteralContext => Seq(astForHereDocLiteral(ctx)) - case _ => - logger.error(s"astForLiteralPrimaryExpression() $relativeFilename, ${text(ctx)} All contexts mismatched.") - Seq() - } - - private def astForSymbol(ctx: SymbolContext): Seq[Ast] = { - if ( - ctx.stringExpression() != null && ctx.stringExpression().children.get(0).isInstanceOf[StringInterpolationContext] - ) { - val node = callNode( - ctx, - code(ctx), - RubyOperators.formattedString, - RubyOperators.formattedString, - DispatchTypes.STATIC_DISPATCH, - None, - Option(Defines.Any) - ) - astForStringExpression(ctx.stringExpression()) ++ Seq(Ast(node)) - } else { - Seq(astForSymbolLiteral(ctx)) - } - } - - protected def astForMultipleRightHandSideContext(ctx: MultipleRightHandSideContext): Seq[Ast] = - if (ctx == null) { - Seq.empty - } else { - val expCmd = ctx.expressionOrCommands() - val exprAsts = Option(expCmd) match - case Some(expCmd) => - expCmd.expressionOrCommand().asScala.flatMap(astForExpressionOrCommand).toSeq - case None => - Seq.empty - - if (ctx.splattingArgument != null) { - val splattingAsts = astForExpressionOrCommand(ctx.splattingArgument.expressionOrCommand) - exprAsts ++ splattingAsts - } else { - exprAsts - } - } - - protected def astForSingleLeftHandSideContext(ctx: SingleLeftHandSideContext): Seq[Ast] = ctx match { - case ctx: VariableIdentifierOnlySingleLeftHandSideContext => - Seq(astForVariableIdentifierHelper(ctx.variableIdentifier, true)) - case ctx: PrimaryInsideBracketsSingleLeftHandSideContext => - val primaryAsts = astForPrimaryContext(ctx.primary) - val argsAsts = astForArguments(ctx.arguments) - val indexAccessCall = createOpCall(ctx.LBRACK, Operators.indexAccess, code(ctx)) - Seq(callAst(indexAccessCall, primaryAsts ++ argsAsts)) - case ctx: XdotySingleLeftHandSideContext => - // TODO handle obj.foo=arg being interpreted as obj.foo(arg) here. - val xAsts = astForPrimaryContext(ctx.primary) - - Seq(ctx.LOCAL_VARIABLE_IDENTIFIER, ctx.CONSTANT_IDENTIFIER) - .flatMap(Option(_)) - .headOption match - case Some(localVar) => - val name = localVar.getSymbol.getText - val node = createIdentifierWithScope(ctx, name, name, Defines.Any, List(Defines.Any), true) - val yAst = Ast(node) - - val callNode = createOpCall(localVar, Operators.fieldAccess, code(ctx)) - Seq(callAst(callNode, xAsts ++ Seq(yAst))) - case None => - Seq.empty - case ctx: ScopedConstantAccessSingleLeftHandSideContext => - val localVar = ctx.CONSTANT_IDENTIFIER - val varSymbol = localVar.getSymbol - val node = - createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any), true) - Seq(Ast(node)) - case _ => - logger.error(s"astForSingleLeftHandSideContext() $relativeFilename, ${text(ctx)} All contexts mismatched.") - Seq.empty - } - - protected def astForSingleAssignmentExpressionContext(ctx: SingleAssignmentExpressionContext): Seq[Ast] = { - val rightAst = astForMultipleRightHandSideContext(ctx.multipleRightHandSide) - val leftAst = astForSingleLeftHandSideContext(ctx.singleLeftHandSide) - - val operatorName = getOperatorName(ctx.op) - val opCallNode = - callNode(ctx, code(ctx), operatorName, operatorName, DispatchTypes.STATIC_DISPATCH, None, Option(Defines.Any)) - .lineNumber(ctx.op.getLine) - .columnNumber(ctx.op.getCharPositionInLine) - if (leftAst.size == 1 && rightAst.size > 1) { - /* - * This is multiple RHS packed into a single LHS. That is, packing left hand side. - * This is as good as multiple RHS packed into an array and put into a single LHS - */ - val packedRHS = getPackedRHS(rightAst, wrapInBrackets = true) - Seq(callAst(opCallNode, leftAst ++ packedRHS)) - } else { - Seq(callAst(opCallNode, leftAst ++ rightAst)) - } - } - - protected def astForMultipleAssignmentExpressionContext(ctx: MultipleAssignmentExpressionContext): Seq[Ast] = { - val rhsAsts = astForMultipleRightHandSideContext(ctx.multipleRightHandSide()) - val lhsAsts = astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide()) - val operatorName = getOperatorName(ctx.EQ.getSymbol) - - /* - * This is multiple LHS and multiple RHS - *Since we have multiple LHS and RHS elements here, we will now create synthetic assignment - * call nodes to model how ruby assigns values from RHS elements to LHS elements. We create - * tuples for each assignment and then pass them to the assignment calls nodes - */ - val assigns = - if (lhsAsts.size < rhsAsts.size) { - /* The rightmost AST in the LHS is a packed variable. - * Pack the extra ASTs and the rightmost AST in the RHS in one array like the if() part - */ - val diff = rhsAsts.size - lhsAsts.size - val packedRHS = getPackedRHS(rhsAsts.takeRight(diff + 1)).headOption.to(Seq) - val alignedAsts = lhsAsts.take(lhsAsts.size - 1) zip rhsAsts.take(lhsAsts.size - 1) - val packedAsts = lhsAsts.takeRight(1) zip packedRHS - alignedAsts ++ packedAsts - } else { - lhsAsts.zip(rhsAsts) - } - - assigns.map { case (lhsAst, rhsAst) => - val lhsCode = lhsAst.nodes.collectFirst { case x: AstNodeNew => x.code }.getOrElse("") - val rhsCode = rhsAst.nodes.collectFirst { case x: AstNodeNew => x.code }.getOrElse("") - val code = s"$lhsCode = $rhsCode" - val syntheticCallNode = createOpCall(ctx.EQ, operatorName, code) - - callAst(syntheticCallNode, Seq(lhsAst, rhsAst)) - } - } - - protected def astForIndexingExpressionPrimaryContext(ctx: IndexingExpressionPrimaryContext): Seq[Ast] = { - val lhsExpressionAst = astForPrimaryContext(ctx.primary()) - val rhsExpressionAst = Option(ctx.indexingArguments).map(astForIndexingArgumentsContext).getOrElse(Seq()) - - val operator = lhsExpressionAst.flatMap(_.nodes).collectFirst { case x: NewIdentifier => x } match - case Some(node) if node.name == "Array" => Operators.arrayInitializer - case _ => Operators.indexAccess - - val callNode = createOpCall(ctx.LBRACK, operator, code(ctx)) - Seq(callAst(callNode, lhsExpressionAst ++ rhsExpressionAst)) - - } - - private def getPackedRHS(astsToConcat: Seq[Ast], wrapInBrackets: Boolean = false) = { - val code = astsToConcat - .flatMap(_.nodes) - .collect { case x: AstNodeNew => x.code } - .mkString(", ") - - val callNode = NewCall() - .name(Operators.arrayInitializer) - .methodFullName(Operators.arrayInitializer) - .typeFullName(Defines.Any) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .code(if (wrapInBrackets) s"[$code]" else code) - Seq(callAst(callNode, astsToConcat)) - } - - def astForStringInterpolationContext(ctx: InterpolatedStringExpressionContext): Seq[Ast] = { - val varAsts = ctx.stringInterpolation.interpolatedStringSequence.asScala - .flatMap(inter => - Seq( - Ast( - callNode( - ctx, - code(inter), - RubyOperators.formattedValue, - RubyOperators.formattedValue, - DispatchTypes.STATIC_DISPATCH, - None, - Option(Defines.Any) - ) - ) - ) ++ - astForStatements(inter.compoundStatement.statements, false, false) - ) - .toSeq - - val literalAsts = ctx - .stringInterpolation() - .DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE() - .asScala - .map(substr => - Ast( - createLiteralNode( - substr.getText, - Defines.String, - List(Defines.String), - Option(substr.lineNumber), - Option(substr.columnNumber) - ) - ) - ) - .toSeq - varAsts ++ literalAsts - } - - // TODO: Return Ast instead of Seq[Ast] - protected def astForStringExpression(ctx: StringExpressionContext): Seq[Ast] = ctx match { - case ctx: SimpleStringExpressionContext => Seq(astForSimpleString(ctx.simpleString)) - case ctx: InterpolatedStringExpressionContext => astForStringInterpolationContext(ctx) - case ctx: ConcatenatedStringExpressionContext => Seq(astForConcatenatedStringExpressions(ctx)) - } - - // Regex interpolation has been modeled just as a set of statements, that suffices to track dataflows - protected def astForRegexInterpolationPrimaryContext(ctx: RegexInterpolationContext): Seq[Ast] = { - val varAsts = ctx - .interpolatedRegexSequence() - .asScala - .flatMap(inter => { - astForStatements(inter.compoundStatement().statements(), false, false) - }) - .toSeq - varAsts - } - - protected def astForSimpleString(ctx: SimpleStringContext): Ast = ctx match { - case ctx: SingleQuotedStringLiteralContext => astForSingleQuotedStringLiteral(ctx) - case ctx: DoubleQuotedStringLiteralContext => astForDoubleQuotedStringLiteral(ctx) - } - - protected def astForConcatenatedStringExpressions(ctx: ConcatenatedStringExpressionContext): Ast = { - val stringExpressionAsts = ctx.stringExpression().asScala.flatMap(astForStringExpression) - val callNode_ = callNode( - ctx, - code(ctx), - RubyOperators.stringConcatenation, - RubyOperators.stringConcatenation, - DispatchTypes.STATIC_DISPATCH - ) - callAst(callNode_, stringExpressionAsts.toSeq) - } - - protected def astForTernaryConditionalOperator(ctx: ConditionalOperatorExpressionContext): Ast = { - val testAst = astForExpressionContext(ctx.expression(0)) - val thenAst = astForExpressionContext(ctx.expression(1)) - val elseAst = astForExpressionContext(ctx.expression(2)) - val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, code(ctx)) - controlStructureAst(ifNode, testAst.headOption, thenAst ++ elseAst) - } - - def astForRangeExpressionContext(ctx: RangeExpressionContext): Seq[Ast] = - Seq(astForBinaryOperatorExpression(ctx, Operators.range, ctx.expression().asScala)) - - protected def astForSuperExpression(ctx: SuperExpressionPrimaryContext): Ast = { - val argsAst = Option(ctx.argumentsWithParentheses()) match - case Some(ctxArgs) => astForArgumentsWithParenthesesContext(ctxArgs) - case None => Seq() - astForSuperCall(ctx, argsAst) - } - - // TODO: Handle the optional block. - // NOTE: `super` is quite complicated semantically speaking. We'll need - // to revisit how to represent them. - protected def astForSuperCall(ctx: ParserRuleContext, arguments: Seq[Ast]): Ast = { - val call = - callNode(ctx, code(ctx), RubyOperators.superKeyword, RubyOperators.superKeyword, DispatchTypes.STATIC_DISPATCH) - callAst(call, arguments.toList) - } - - protected def astForYieldCall(ctx: ParserRuleContext, argumentsCtx: Option[ArgumentsContext]): Ast = { - val args = argumentsCtx.map(astForArguments).getOrElse(Seq()) - val call = callNode(ctx, code(ctx), UNRESOLVED_YIELD, UNRESOLVED_YIELD, DispatchTypes.STATIC_DISPATCH) - callAst(call, args) - } - - protected def astForUntilExpression(ctx: UntilExpressionContext): Ast = { - val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()).headOption - val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) - // TODO: testAst should be negated if it's going to be modelled as a while stmt. - whileAst(testAst, bodyAst, Some(text(ctx)), line(ctx), column(ctx)) - } - - protected def astForForExpression(ctx: ForExpressionContext): Ast = { - val forVarAst = astForForVariableContext(ctx.forVariable()) - val forExprAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val forBodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) - // TODO: for X in Y is not properly modelled by while Y - val forRootAst = whileAst(forExprAst.headOption, forBodyAst, Some(text(ctx)), line(ctx), column(ctx)) - forVarAst.headOption.map(forRootAst.withChild).getOrElse(forRootAst) - } - - private def astForForVariableContext(ctx: ForVariableContext): Seq[Ast] = { - if (ctx.singleLeftHandSide() != null) { - astForSingleLeftHandSideContext(ctx.singleLeftHandSide()) - } else if (ctx.multipleLeftHandSide() != null) { - astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide()) - } else { - Seq(Ast()) - } - } - - protected def astForWhileExpression(ctx: WhileExpressionContext): Ast = { - val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) - whileAst(testAst.headOption, bodyAst, Some(text(ctx)), line(ctx), column(ctx)) - } - - protected def astForIfExpression(ctx: IfExpressionContext): Ast = { - val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) - val elsifAsts = Option(ctx.elsifClause).map(_.asScala).getOrElse(Seq()).map(astForElsifClause) - val elseAst = Option(ctx.elseClause()).map(ctx => astForCompoundStatement(ctx.compoundStatement())).getOrElse(Seq()) - val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, code(ctx)) - controlStructureAst(ifNode, testAst.headOption) - .withChildren(thenAst) - .withChildren(elsifAsts.toSeq) - .withChildren(elseAst) - } - - private def astForElsifClause(ctx: ElsifClauseContext): Ast = { - val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, code(ctx)) - val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val bodyAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) - controlStructureAst(ifNode, testAst.headOption, bodyAst) - } - - protected def astForVariableReference(ctx: VariableReferenceContext): Ast = ctx match { - case ctx: VariableIdentifierVariableReferenceContext => astForVariableIdentifierHelper(ctx.variableIdentifier()) - case ctx: PseudoVariableIdentifierVariableReferenceContext => - astForPseudoVariableIdentifier(ctx.pseudoVariableIdentifier()) - } - - private def astForPseudoVariableIdentifier(ctx: PseudoVariableIdentifierContext): Ast = ctx match { - case ctx: NilPseudoVariableIdentifierContext => astForNilLiteral(ctx) - case ctx: TruePseudoVariableIdentifierContext => astForTrueLiteral(ctx) - case ctx: FalsePseudoVariableIdentifierContext => astForFalseLiteral(ctx) - case ctx: SelfPseudoVariableIdentifierContext => astForSelfPseudoIdentifier(ctx) - case ctx: FilePseudoVariableIdentifierContext => astForFilePseudoIdentifier(ctx) - case ctx: LinePseudoVariableIdentifierContext => astForLinePseudoIdentifier(ctx) - case ctx: EncodingPseudoVariableIdentifierContext => astForEncodingPseudoIdentifier(ctx) - } - - protected def astForVariableIdentifierHelper( - ctx: VariableIdentifierContext, - definitelyIdentifier: Boolean = false - ): Ast = { - /* - * Preferences - * 1. If definitelyIdentifier is SET, create a identifier node - * 2. If an identifier with the variable name exists within the scope, create a identifier node - * 3. If a method with the variable name exists, create a method node - * 4. Otherwise default to identifier node creation since there is no reason (point 2) to create a call node - */ - - val variableName = code(ctx) - val isSelfFieldAccess = variableName.startsWith("@") - if (isSelfFieldAccess) { - // Very basic field detection - fieldReferences.updateWith(classStack.top) { - case Some(xs) => Option(xs ++ Set(ctx)) - case None => Option(Set(ctx)) - } - val thisNode = createThisIdentifier(ctx) - astForFieldAccess(ctx, thisNode) - } else if (definitelyIdentifier || scope.lookupVariable(variableName).isDefined) { - val node = createIdentifierWithScope(ctx, variableName, variableName, Defines.Any, List(), definitelyIdentifier) - Ast(node) - } else if (methodNameToMethod.contains(variableName)) { - astForCallNode(ctx, variableName) - } else if (ModifierTypes.ALL.contains(variableName.toUpperCase)) { - lastModifier = Option(variableName.toUpperCase) - Ast() - } else if (ctx.GLOBAL_VARIABLE_IDENTIFIER() != null) { - val globalVar = ctx.GLOBAL_VARIABLE_IDENTIFIER().getText - Ast(createIdentifierWithScope(ctx, globalVar, globalVar, Defines.String, List())) - } else { - val node = createIdentifierWithScope(ctx, variableName, variableName, Defines.Any, List()) - Ast(node) - } - } - - protected def astForUnlessExpression(ctx: UnlessExpressionContext): Ast = { - val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) - val elseAst = - Option(ctx.elseClause()).map(_.compoundStatement()).map(st => astForCompoundStatement(st)).getOrElse(Seq()) - val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, code(ctx)) - controlStructureAst(ifNode, testAst.headOption, thenAst ++ elseAst) - } - - protected def astForQuotedStringExpression(ctx: QuotedStringExpressionContext): Seq[Ast] = ctx match - case ctx: NonExpandedQuotedStringLiteralContext => Seq(astForNonExpandedQuotedString(ctx)) - case _ => - logger.error(s"Translation for ${text(ctx)} not implemented yet") - Seq() - - private def astForNonExpandedQuotedString(ctx: NonExpandedQuotedStringLiteralContext): Ast = { - Ast(literalNode(ctx, code(ctx), getBuiltInType(Defines.String))) - } - - // TODO: handle interpolation - protected def astForQuotedRegexInterpolation(ctx: QuotedRegexInterpolationContext): Seq[Ast] = { - Seq(Ast(literalNode(ctx, code(ctx), Defines.Regexp))) - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForFunctionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForFunctionsCreator.scala deleted file mode 100644 index 080dfed33e42..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForFunctionsCreator.scala +++ /dev/null @@ -1,417 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.* -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.rubysrc2cpg.deprecated.utils.PackageContext -import io.joern.x2cpg.utils.NodeBuilders.newModifierNode -import io.joern.x2cpg.{Ast, ValidationMode, Defines as XDefines} -import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, ModifierTypes} -import org.antlr.v4.runtime.ParserRuleContext -import org.antlr.v4.runtime.tree.TerminalNode -import org.slf4j.LoggerFactory - -import scala.collection.mutable -import scala.collection.mutable.ListBuffer -import scala.jdk.CollectionConverters.* - -trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { - this: AstCreator => - - private val logger = LoggerFactory.getLogger(getClass) - - /* - *Fake methods created from yield blocks and their yield calls will have this suffix in their names - */ - protected val YIELD_SUFFIX = "_yield" - - /* - * This is used to mark call nodes created due to yield calls. This is set in their names at creation. - * The appropriate name wrt the names of their actual methods is set later in them. - */ - protected val UNRESOLVED_YIELD = "unresolved_yield" - - /* - * Stack of variable identifiers incorrectly identified as method identifiers - * Each AST contains exactly one call or identifier node - */ - protected val methodNameAsIdentifierStack: mutable.Stack[Ast] = mutable.Stack.empty - protected val methodAliases: mutable.HashMap[String, String] = mutable.HashMap.empty - protected val methodNameToMethod: mutable.HashMap[String, NewMethod] = mutable.HashMap.empty - protected val methodDefInArgument: ListBuffer[Ast] = ListBuffer.empty - protected val methodNamesWithYield: mutable.HashSet[String] = mutable.HashSet.empty - protected val blockMethods: ListBuffer[Ast] = ListBuffer.empty - - /** @return - * the method name if found as an alias, or the given name if not found. - */ - protected def resolveAlias(name: String): String = { - methodAliases.getOrElse(name, name) - } - - protected def astForMethodDefinitionContext(ctx: MethodDefinitionContext): Seq[Ast] = { - val astMethodName = Option(ctx.methodNamePart()) match - case Some(ctxMethodNamePart) => - astForMethodNamePartContext(ctxMethodNamePart) - case None => - astForMethodIdentifierContext(ctx.methodIdentifier(), code(ctx)) - val callNode = astMethodName.head.nodes.filter(node => node.isInstanceOf[NewCall]).head.asInstanceOf[NewCall] - - // Create thisParameter if this is an instance method - // TODO may need to revisit to make this more robust - - val (methodName, methodFullName) = if (callNode.name == Defines.Initialize) { - (XDefines.ConstructorMethodName, classStack.reverse :+ XDefines.ConstructorMethodName mkString pathSep) - } else { - (callNode.name, classStack.reverse :+ callNode.name mkString pathSep) - } - val newMethodNode = methodNode(ctx, methodName, code(ctx), methodFullName, None, relativeFilename) - .columnNumber(callNode.columnNumber) - .lineNumber(callNode.lineNumber) - - scope.pushNewScope(newMethodNode) - - val astMethodParamSeq = ctx.methodNamePart() match { - case _: SimpleMethodNamePartContext if !classStack.top.endsWith(":program") => - val thisParameterNode = createMethodParameterIn( - "this", - typeFullName = callNode.methodFullName, - lineNumber = callNode.lineNumber, - colNumber = callNode.columnNumber, - index = 0, - order = 0 - ) - Seq(Ast(thisParameterNode)) ++ astForMethodParameterPartContext(ctx.methodParameterPart()) - case _ => astForMethodParameterPartContext(ctx.methodParameterPart()) - } - - Option(ctx.END()).foreach(endNode => newMethodNode.lineNumberEnd(endNode.getSymbol.getLine)) - - callNode.methodFullName(methodFullName) - - val classType = if (classStack.isEmpty) "Standalone" else classStack.top - val classPath = classStack.reverse.toList.mkString(pathSep) - packageContext.packageTable.addPackageMethod(packageContext.moduleName, callNode.name, classPath, classType) - - val astBody = Option(ctx.bodyStatement()) match { - case Some(ctxBodyStmt) => astForBodyStatementContext(ctxBodyStmt, true) - case None => - val expAst = astForExpressionContext(ctx.expression()) - Seq(lastStmtAsReturnAst(ctx, expAst.head, Option(text(ctx.expression())))) - } - - // process yield calls. - astBody - .flatMap(_.nodes.collect { case x: NewCall => x }.filter(_.name == UNRESOLVED_YIELD)) - .foreach { yieldCallNode => - val name = newMethodNode.name - val methodFullName = classStack.reverse :+ callNode.name mkString pathSep - yieldCallNode.name(name + YIELD_SUFFIX) - yieldCallNode.methodFullName(methodFullName + YIELD_SUFFIX) - methodNamesWithYield.add(newMethodNode.name) - /* - * These are calls to the yield block of this method. - * Add this method to the list of yield blocks. - * The add() is idempotent and so adding the same method multiple times makes no difference. - * It just needs to be added at this place so that it gets added iff it has a yield block - */ - } - - val methodRetNode = NewMethodReturn().typeFullName(Defines.Any) - - val modifierNode = lastModifier match { - case Some(modifier) => NewModifier().modifierType(modifier).code(modifier) - case None => NewModifier().modifierType(ModifierTypes.PUBLIC).code(ModifierTypes.PUBLIC) - } - /* - * public/private/protected modifiers are in a separate statement - * TODO find out how they should be used. Need to do this iff it adds any value - */ - if (methodName != XDefines.ConstructorMethodName) { - methodNameToMethod.put(newMethodNode.name, newMethodNode) - } - - /* Before creating ast, we traverse the method params and identifiers and link them*/ - val identifiers = - astBody.flatMap(ast => ast.nodes.filter(_.isInstanceOf[NewIdentifier])).asInstanceOf[Seq[NewIdentifier]] - - val params = astMethodParamSeq - .flatMap(_.nodes.collect { case x: NewMethodParameterIn => x }) - .toList - val locals = scope.createAndLinkLocalNodes(diffGraph, params.map(_.name).toSet) - - params.foreach { param => - identifiers.filter(_.name == param.name).foreach { identifier => - diffGraph.addEdge(identifier, param, EdgeTypes.REF) - } - } - scope.popScope() - - Seq( - methodAst( - newMethodNode, - astMethodParamSeq, - blockAst(blockNode(ctx), locals.map(Ast.apply) ++ astBody.toList), - methodRetNode, - Seq[NewModifier](modifierNode) - ) - ) - } - - private def astForOperatorMethodNameContext(ctx: OperatorMethodNameContext): Seq[Ast] = { - /* - * This is for operator overloading for the class - */ - val name = code(ctx) - val methodFullName = classStack.reverse :+ name mkString pathSep - - val node = callNode(ctx, code(ctx), name, methodFullName, DispatchTypes.STATIC_DISPATCH, None, Option(Defines.Any)) - ctx.children.asScala - .collectFirst { case x: TerminalNode => x } - .foreach(x => node.lineNumber(x.lineNumber).columnNumber(x.columnNumber)) - Seq(callAst(node)) - } - - protected def astForMethodNameContext(ctx: MethodNameContext): Seq[Ast] = { - if (ctx.methodIdentifier() != null) { - astForMethodIdentifierContext(ctx.methodIdentifier(), code(ctx)) - } else if (ctx.operatorMethodName() != null) { - astForOperatorMethodNameContext(ctx.operatorMethodName) - } else if (ctx.keyword() != null) { - val node = - callNode(ctx, code(ctx), code(ctx), code(ctx), DispatchTypes.STATIC_DISPATCH, None, Option(Defines.Any)) - ctx.children.asScala - .collectFirst { case x: TerminalNode => x } - .foreach(x => - node.lineNumber(x.lineNumber).columnNumber(x.columnNumber).name(x.getText).methodFullName(x.getText) - ) - Seq(callAst(node)) - } else { - Seq.empty - } - } - - private def astForSingletonMethodNamePartContext(ctx: SingletonMethodNamePartContext): Seq[Ast] = { - val definedMethodNameAst = astForDefinedMethodNameContext(ctx.definedMethodName()) - val singletonObjAst = astForSingletonObjectContext(ctx.singletonObject()) - definedMethodNameAst ++ singletonObjAst - } - - private def astForSingletonObjectContext(ctx: SingletonObjectContext): Seq[Ast] = { - if (ctx.variableIdentifier() != null) { - Seq(astForVariableIdentifierHelper(ctx.variableIdentifier(), true)) - } else if (ctx.pseudoVariableIdentifier() != null) { - Seq(Ast()) - } else if (ctx.expressionOrCommand() != null) { - astForExpressionOrCommand(ctx.expressionOrCommand()) - } else { - Seq.empty - } - } - - private def astForParametersContext(ctx: ParametersContext): Seq[Ast] = { - if (ctx == null) return Seq() - - // the parameterTupleList holds the parameter terminal node and is the parameter a variadic parameter - val parameterTupleList = ctx.parameter().asScala.map { - case procCtx if procCtx.procParameter() != null => - (Option(procCtx.procParameter().LOCAL_VARIABLE_IDENTIFIER()), false) - case optCtx if optCtx.optionalParameter() != null => - (Option(optCtx.optionalParameter().LOCAL_VARIABLE_IDENTIFIER()), false) - case manCtx if manCtx.mandatoryParameter() != null => - (Option(manCtx.mandatoryParameter().LOCAL_VARIABLE_IDENTIFIER()), false) - case arrCtx if arrCtx.arrayParameter() != null => - (Option(arrCtx.arrayParameter().LOCAL_VARIABLE_IDENTIFIER()), arrCtx.arrayParameter().STAR() != null) - case keywordCtx if keywordCtx.keywordParameter() != null => - (Option(keywordCtx.keywordParameter().LOCAL_VARIABLE_IDENTIFIER()), false) - case _ => (None, false) - } - - parameterTupleList.zipWithIndex.map { case (paraTuple, paraIndex) => - paraTuple match - case (Some(paraValue), isVariadic) => - val varSymbol = paraValue.getSymbol - createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, Seq[String](Defines.Any)) - Ast( - createMethodParameterIn( - varSymbol.getText, - lineNumber = Some(varSymbol.getLine), - colNumber = Some(varSymbol.getCharPositionInLine), - order = paraIndex + 1, - index = paraIndex + 1 - ).isVariadic(isVariadic) - ) - case _ => - Ast( - createMethodParameterIn( - getUnusedVariableNames(usedVariableNames, Defines.TempParameter), - order = paraIndex + 1, - index = paraIndex + 1 - ) - ) - }.toList - } - - // TODO: Rewrite for simplicity and take into account more than parameter names. - private def astForMethodParameterPartContext(ctx: MethodParameterPartContext): Seq[Ast] = { - if (ctx == null || ctx.parameters == null) Seq.empty - else astForParametersContext(ctx.parameters) - } - - private def astForDefinedMethodNameContext(ctx: DefinedMethodNameContext): Seq[Ast] = { - Option(ctx.methodName()) match - case Some(methodNameCtx) => astForMethodNameContext(methodNameCtx) - case None => astForAssignmentLikeMethodIdentifierContext(ctx.assignmentLikeMethodIdentifier()) - } - - private def astForAssignmentLikeMethodIdentifierContext(ctx: AssignmentLikeMethodIdentifierContext): Seq[Ast] = { - Seq( - callAst( - callNode(ctx, code(ctx), code(ctx), code(ctx), DispatchTypes.STATIC_DISPATCH, Some(""), Some(Defines.Any)) - ) - ) - } - - private def astForMethodNamePartContext(ctx: MethodNamePartContext): Seq[Ast] = ctx match { - case ctx: SimpleMethodNamePartContext => astForSimpleMethodNamePartContext(ctx) - case ctx: SingletonMethodNamePartContext => astForSingletonMethodNamePartContext(ctx) - case _ => - logger.error(s"astForMethodNamePartContext() $relativeFilename, ${text(ctx)} All contexts mismatched.") - Seq(Ast()) - } - - private def astForSimpleMethodNamePartContext(ctx: SimpleMethodNamePartContext): Seq[Ast] = - astForDefinedMethodNameContext(ctx.definedMethodName) - - protected def methodForClosureStyleFn(ctx: ParserRuleContext): NewMethod = { - val procMethodName = s"proc_${blockIdCounter.getAndAdd(1)}" - val methodFullName = classStack.reverse :+ procMethodName mkString pathSep - methodNode(ctx, procMethodName, code(ctx), methodFullName, None, relativeFilename) - } - - protected def astForProcDefinitionContext(ctx: ProcDefinitionContext): Seq[Ast] = { - /* - * Model a proc as a method - */ - // Note: For parameters in the Proc definition, an implicit parameter which goes by the name of `this` is added to the cpg - val newMethodNode = methodForClosureStyleFn(ctx) - - scope.pushNewScope(newMethodNode) - - val astMethodParam = astForParametersContext(ctx.parameters()) - val paramNames = astMethodParam.flatMap(_.nodes).collect { case x: NewMethodParameterIn => x.name }.toSet - val astBody = astForCompoundStatement(ctx.block.compoundStatement, true) - val locals = scope.createAndLinkLocalNodes(diffGraph, paramNames).map(Ast.apply) - - val methodRetNode = NewMethodReturn() - .typeFullName(Defines.Any) - - val modifiers = newModifierNode(ModifierTypes.PUBLIC) :: newModifierNode(ModifierTypes.LAMBDA) :: Nil - - val methAst = methodAst( - newMethodNode, - astMethodParam, - blockAst(blockNode(ctx), locals ++ astBody.toList), - methodRetNode, - modifiers - ) - blockMethods.addOne(methAst) - - val callArgs = astMethodParam - .flatMap(_.root) - .collect { case x: NewMethodParameterIn => x } - .map(param => Ast(createIdentifierWithScope(ctx, param.name, param.code, Defines.Any, Seq(), true))) - - val procCallNode = - callNode( - ctx, - code(ctx), - newMethodNode.name, - newMethodNode.fullName, - DispatchTypes.STATIC_DISPATCH, - None, - Option(Defines.Any) - ) - - scope.popScope() - - Seq(callAst(procCallNode, callArgs)) - } - - def astForDefinedMethodNameOrSymbolContext(ctx: DefinedMethodNameOrSymbolContext): Seq[Ast] = - if (ctx == null) { - Seq.empty - } else { - if (ctx.definedMethodName() != null) { - astForDefinedMethodNameContext(ctx.definedMethodName()) - } else { - Seq(astForSymbolLiteral(ctx.symbol())) - } - } - - protected def astForBlockFunction( - ctxStmt: StatementsContext, - ctxParam: Option[BlockParameterContext], - blockMethodName: String, - lineStart: Int, - lineEnd: Int, - colStart: Int, - colEnd: Int - ): Seq[Ast] = { - /* - * Model a block as a method - */ - val methodFullName = classStack.reverse :+ blockMethodName mkString pathSep - val newMethodNode = methodNode(ctxStmt, blockMethodName, code(ctxStmt), methodFullName, None, relativeFilename) - .lineNumber(lineStart) - .lineNumberEnd(lineEnd) - .columnNumber(colStart) - .columnNumberEnd(colEnd) - - scope.pushNewScope(newMethodNode) - val astMethodParam = ctxParam.map(astForBlockParameterContext).getOrElse(Seq()) - - val publicModifier = NewModifier().modifierType(ModifierTypes.PUBLIC) - val paramSeq = astMethodParam.flatMap(_.root).map { - /* In majority of cases, node will be an identifier */ - case identifierNode: NewIdentifier => - val param = NewMethodParameterIn() - .name(identifierNode.name) - .code(identifierNode.code) - .typeFullName(identifierNode.typeFullName) - .lineNumber(identifierNode.lineNumber) - .columnNumber(identifierNode.columnNumber) - .dynamicTypeHintFullName(identifierNode.dynamicTypeHintFullName) - Ast(param) - case _: NewCall => - /* TODO: Occasionally, we might encounter a _ call in cases like "do |_, x|" where we should handle this? - * But for now, we just return an empty AST. Keeping this match explicitly here so we come back */ - Ast() - case _ => - Ast() - } - val paramNames = (astMethodParam ++ paramSeq) - .flatMap(_.root) - .collect { - case x: NewMethodParameterIn => x.name - case x: NewIdentifier => x.name - } - .toSet - val astBody = astForStatements(ctxStmt, true) - val locals = scope.createAndLinkLocalNodes(diffGraph, paramNames).map(Ast.apply) - val methodRetNode = NewMethodReturn().typeFullName(Defines.Any) - - scope.popScope() - - Seq( - methodAst( - newMethodNode, - paramSeq, - blockAst(blockNode(ctxStmt), locals ++ astBody.toList), - methodRetNode, - Seq(publicModifier) - ) - ) - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForHereDocsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForHereDocsCreator.scala deleted file mode 100644 index 64cb8726d8cb..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForHereDocsCreator.scala +++ /dev/null @@ -1,74 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.{HereDocArgumentContext, HereDocLiteralContext} -import io.joern.rubysrc2cpg.deprecated.parser.HereDocHandling -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.nodes.{NewCall, NewIdentifier, NewLiteral} - -import scala.collection.immutable.Seq -import scala.collection.mutable - -trait AstForHereDocsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - private val hereDocTokens = mutable.Stack[(String, NewLiteral)]() - - protected def astForHereDocLiteral(ctx: HereDocLiteralContext): Ast = { - val delimiter = HereDocHandling.getHereDocDelimiter(ctx.HERE_DOC().getText).getOrElse("") - val hereDoc = ctx.HERE_DOC().getText.replaceFirst("<<[~-]", "") - val hereDocTxt = hereDoc.stripPrefix(delimiter).stripSuffix(delimiter).strip() - val literal = NewLiteral() - .code(hereDocTxt) - .typeFullName(Defines.String) - .lineNumber(line(ctx)) - .columnNumber(column(ctx)) - Ast(literal) - } - - protected def astForHereDocArgument(ctx: HereDocArgumentContext): Seq[Ast] = - HereDocHandling.getHereDocDelimiter(ctx.HERE_DOC_IDENTIFIER().getText) match - case Some(delimiter) => - val literal = NewLiteral() - .code("") // build code from the upcoming statements - .typeFullName(Defines.String) - .lineNumber(line(ctx)) - .columnNumber(column(ctx)) - hereDocTokens.push((delimiter, literal)) - Seq(Ast(literal)) - case None => Seq.empty - - /** Will determine, if we have recently met a here doc initializer, if this statement should be converted to a here - * doc literal or returned as-is. - * @param stmt - * the statement AST. - * @return - * the statement AST or nothing if this is determined to be a here doc body. - */ - protected def scanStmtForHereDoc(stmt: Seq[Ast]): Seq[Ast] = { - if (stmt.nonEmpty && hereDocTokens.nonEmpty) { - val (delimiter, literalNode) = hereDocTokens.head - val stmtAst = stmt.head - val atHereDocInitializer = stmt.flatMap(_.nodes).exists { - case x: NewLiteral => hereDocTokens.exists(_._2 == x) - case _ => false - } - if (atHereDocInitializer) { - // We are at the start of the here doc, do nothing - stmt - } else { - // We are in the middle of the here doc, convert statements to here doc body + look out for delimiter - val txt = stmtAst.root match - case Some(x: NewCall) => x.code - case Some(x: NewIdentifier) => x.code - case _ => "" - - if (txt == delimiter) hereDocTokens.pop() - else literalNode.code(s"${literalNode.code}\n$txt".trim) - Seq.empty[Ast] - } - } else { - stmt - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForPrimitivesCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForPrimitivesCreator.scala deleted file mode 100644 index 2821d4d8658e..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForPrimitivesCreator.scala +++ /dev/null @@ -1,100 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.* -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.rubysrc2cpg.deprecated.passes.Defines.getBuiltInType -import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.nodes.NewCall -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import org.antlr.v4.runtime.ParserRuleContext - -import scala.jdk.CollectionConverters.CollectionHasAsScala - -trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - protected def astForNilLiteral(ctx: DeprecatedRubyParser.NilPseudoVariableIdentifierContext): Ast = - Ast(literalNode(ctx, code(ctx), Defines.NilClass)) - - protected def astForTrueLiteral(ctx: DeprecatedRubyParser.TruePseudoVariableIdentifierContext): Ast = - Ast(literalNode(ctx, code(ctx), Defines.TrueClass)) - - protected def astForFalseLiteral(ctx: DeprecatedRubyParser.FalsePseudoVariableIdentifierContext): Ast = - Ast(literalNode(ctx, code(ctx), Defines.FalseClass)) - - protected def astForSelfPseudoIdentifier(ctx: DeprecatedRubyParser.SelfPseudoVariableIdentifierContext): Ast = - Ast(createIdentifierWithScope(ctx, code(ctx), code(ctx), Defines.Object)) - - protected def astForFilePseudoIdentifier(ctx: DeprecatedRubyParser.FilePseudoVariableIdentifierContext): Ast = - Ast(createIdentifierWithScope(ctx, code(ctx), code(ctx), getBuiltInType(Defines.String))) - - protected def astForLinePseudoIdentifier(ctx: DeprecatedRubyParser.LinePseudoVariableIdentifierContext): Ast = - Ast(createIdentifierWithScope(ctx, code(ctx), code(ctx), getBuiltInType(Defines.Integer))) - - protected def astForEncodingPseudoIdentifier(ctx: DeprecatedRubyParser.EncodingPseudoVariableIdentifierContext): Ast = - Ast(createIdentifierWithScope(ctx, code(ctx), code(ctx), Defines.Encoding)) - - protected def astForNumericLiteral(ctx: DeprecatedRubyParser.NumericLiteralContext): Ast = { - val numericTypeName = - if (isFloatLiteral(ctx.unsignedNumericLiteral)) getBuiltInType(Defines.Float) else getBuiltInType(Defines.Integer) - Ast(literalNode(ctx, code(ctx), numericTypeName)) - } - - protected def astForSymbolLiteral(ctx: DeprecatedRubyParser.SymbolContext): Ast = - Ast(literalNode(ctx, code(ctx), Defines.Symbol)) - - protected def astForSingleQuotedStringLiteral(ctx: DeprecatedRubyParser.SingleQuotedStringLiteralContext): Ast = - Ast(literalNode(ctx, code(ctx), getBuiltInType(Defines.String))) - - protected def astForDoubleQuotedStringLiteral(ctx: DeprecatedRubyParser.DoubleQuotedStringLiteralContext): Ast = - Ast(literalNode(ctx, code(ctx), getBuiltInType(Defines.String))) - - protected def astForRegularExpressionLiteral(ctx: DeprecatedRubyParser.RegularExpressionLiteralContext): Ast = - Ast(literalNode(ctx, code(ctx), Defines.Regexp)) - - private def isFloatLiteral(ctx: DeprecatedRubyParser.UnsignedNumericLiteralContext): Boolean = - Option(ctx.FLOAT_LITERAL_WITH_EXPONENT).isDefined || Option(ctx.FLOAT_LITERAL_WITHOUT_EXPONENT).isDefined - - // TODO: Return Ast instead of Seq[Ast] - protected def astForArrayLiteral(ctx: ArrayConstructorContext): Seq[Ast] = ctx match - case ctx: BracketedArrayConstructorContext => astForBracketedArrayConstructor(ctx) - case ctx: NonExpandedWordArrayConstructorContext => astForNonExpandedWordArrayConstructor(ctx) - case ctx: NonExpandedSymbolArrayConstructorContext => astForNonExpandedSymbolArrayConstructor(ctx) - - private def astForBracketedArrayConstructor(ctx: BracketedArrayConstructorContext): Seq[Ast] = { - Option(ctx.indexingArguments) - .map(astForIndexingArgumentsContext) - .getOrElse(Seq(astForEmptyArrayInitializer(ctx))) - } - - private def astForEmptyArrayInitializer(ctx: ParserRuleContext): Ast = { - Ast(callNode(ctx, code(ctx), Operators.arrayInitializer, Operators.arrayInitializer, DispatchTypes.STATIC_DISPATCH)) - } - - private def astForNonExpandedWordArrayConstructor(ctx: NonExpandedWordArrayConstructorContext): Seq[Ast] = { - Option(ctx.nonExpandedArrayElements) - .map(astForNonExpandedArrayElements(_, astForNonExpandedWordArrayElement)) - .getOrElse(Seq(astForEmptyArrayInitializer(ctx))) - } - - private def astForNonExpandedWordArrayElement(ctx: NonExpandedArrayElementContext): Ast = { - Ast(literalNode(ctx, code(ctx), Defines.String, List(Defines.String))) - } - - private def astForNonExpandedSymbolArrayConstructor(ctx: NonExpandedSymbolArrayConstructorContext): Seq[Ast] = { - Option(ctx.nonExpandedArrayElements) - .map(astForNonExpandedArrayElements(_, astForNonExpandedSymbolArrayElement)) - .getOrElse(Seq(astForEmptyArrayInitializer(ctx))) - } - - private def astForNonExpandedArrayElements( - ctx: NonExpandedArrayElementsContext, - astForNonExpandedArrayElement: NonExpandedArrayElementContext => Ast - ): Seq[Ast] = { - ctx.nonExpandedArrayElement.asScala.map(astForNonExpandedArrayElement).toSeq - } - - private def astForNonExpandedSymbolArrayElement(ctx: NonExpandedArrayElementContext): Ast = { - Ast(literalNode(ctx, code(ctx), Defines.Symbol)) - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForStatementsCreator.scala deleted file mode 100644 index 925dc9e19720..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForStatementsCreator.scala +++ /dev/null @@ -1,426 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import better.files.File -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.* -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.x2cpg.Defines.DynamicCallUnknownFullName -import io.joern.x2cpg.Imports.createImportNodeAndLink -import io.joern.x2cpg.X2Cpg.stripQuotes -import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators} -import org.antlr.v4.runtime.ParserRuleContext -import org.slf4j.LoggerFactory - -import scala.jdk.CollectionConverters.CollectionHasAsScala - -trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: ValidationMode) { - this: AstCreator => - - private val logger = LoggerFactory.getLogger(this.getClass) - private val prefixMethods = Set( - "attr_reader", - "attr_writer", - "attr_accessor", - "remove_method", - "public_class_method", - "private_class_method", - "private", - "protected", - "module_function" - ) - - private def astForAliasStatement(ctx: AliasStatementContext): Ast = { - val aliasName = ctx.definedMethodNameOrSymbol(0).getText.substring(1) - val methodName = ctx.definedMethodNameOrSymbol(1).getText.substring(1) - methodAliases.addOne(aliasName, methodName) - Ast() - } - - private def astForUndefStatement(ctx: UndefStatementContext): Ast = { - val undefNames = ctx.definedMethodNameOrSymbol().asScala.flatMap(astForDefinedMethodNameOrSymbolContext).toSeq - val call = callNode(ctx, code(ctx), RubyOperators.undef, RubyOperators.undef, DispatchTypes.STATIC_DISPATCH) - callAst(call, undefNames) - } - - private def astForBeginStatement(ctx: BeginStatementContext): Ast = { - val stmts = Option(ctx.compoundStatement).map(astForCompoundStatement(_)).getOrElse(Seq()) - val blockNode = NewBlock().typeFullName(Defines.Any) - blockAst(blockNode, stmts.toList) - } - - private def astForEndStatement(ctx: EndStatementContext): Ast = { - val stmts = Option(ctx.compoundStatement).map(astForCompoundStatement(_)).getOrElse(Seq()) - val blockNode = NewBlock().typeFullName(Defines.Any) - blockAst(blockNode, stmts.toList) - } - - private def astForModifierStatement(ctx: ModifierStatementContext): Ast = ctx.mod.getType match { - case IF => astForIfModifierStatement(ctx) - case UNLESS => astForUnlessModifierStatement(ctx) - case WHILE => astForWhileModifierStatement(ctx) - case UNTIL => astForUntilModifierStatement(ctx) - case RESCUE => astForRescueModifierStatement(ctx) - } - - private def astForIfModifierStatement(ctx: ModifierStatementContext): Ast = { - val lhs = astForStatement(ctx.statement(0)) - val rhs = astForStatement(ctx.statement(1)).headOption - val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, code(ctx)) - lhs.headOption.flatMap(_.root) match - // If the LHS is a `next` command with a return value, then this if statement is its condition and it becomes a - // `return` - case Some(x: NewControlStructure) if x.code == Defines.ModifierNext && lhs.head.nodes.size > 1 => - val retNode = NewReturn().code(Defines.ModifierNext).lineNumber(x.lineNumber).columnNumber(x.columnNumber) - controlStructureAst(ifNode, rhs, Seq(lhs.head.subTreeCopy(x, replacementNode = Option(retNode)))) - case _ => controlStructureAst(ifNode, rhs, lhs) - } - - private def astForUnlessModifierStatement(ctx: ModifierStatementContext): Ast = { - val lhs = astForStatement(ctx.statement(0)) - val rhs = astForStatement(ctx.statement(1)) - val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, code(ctx)) - controlStructureAst(ifNode, lhs.headOption, rhs) - } - - private def astForWhileModifierStatement(ctx: ModifierStatementContext): Ast = { - val lhs = astForStatement(ctx.statement(0)) - val rhs = astForStatement(ctx.statement(1)) - whileAst(rhs.headOption, lhs, Some(text(ctx))) - } - - private def astForUntilModifierStatement(ctx: ModifierStatementContext): Ast = { - val lhs = astForStatement(ctx.statement(0)) - val rhs = astForStatement(ctx.statement(1)) - whileAst(rhs.headOption, lhs, Some(text(ctx))) - } - - private def astForRescueModifierStatement(ctx: ModifierStatementContext): Ast = { - val lhs = astForStatement(ctx.statement(0)) - val rhs = astForStatement(ctx.statement(1)) - val throwNode = controlStructureNode(ctx, ControlStructureTypes.THROW, code(ctx)) - controlStructureAst(throwNode, rhs.headOption, lhs) - } - - /** If the last statement is a return, this is returned. If not, then a return node is created. - */ - protected def lastStmtAsReturnAst(ctx: ParserRuleContext, lastStmtAst: Ast, maybeCode: Option[String] = None): Ast = - lastStmtAst.root.collectFirst { case x: NewReturn => x } match - case Some(_) => lastStmtAst - case None => - val code = maybeCode.getOrElse(text(ctx)) - val retNode = returnNode(ctx, code) - lastStmtAst.root match - case Some(method: NewMethod) => returnAst(retNode, Seq(Ast(methodToMethodRef(ctx, method)))) - case _ => returnAst(retNode, Seq(lastStmtAst)) - - protected def astForBodyStatementContext(ctx: BodyStatementContext, isMethodBody: Boolean = false): Seq[Ast] = { - if (ctx.rescueClause.size > 0) Seq(astForRescueClause(ctx)) - else astForCompoundStatement(ctx.compoundStatement(), isMethodBody) - } - - protected def astForCompoundStatement( - ctx: CompoundStatementContext, - isMethodBody: Boolean = false, - canConsiderAsLeaf: Boolean = true - ): Seq[Ast] = { - val stmtAsts = Option(ctx) - .map(_.statements()) - .map(astForStatements(_, isMethodBody, canConsiderAsLeaf)) - .getOrElse(Seq.empty) - if (isMethodBody) { - stmtAsts - } else { - Seq(blockAst(blockNode(ctx), stmtAsts.toList)) - } - } - - /* - * Each statement set can be considered a block. The blocks of a method can be considered to form a hierarchy. - * We can consider the blocks structure as a n-way tree. Leaf blocks are blocks that have no more sub blocks i.e children in the - * hierarchy. The last statement of the block of the method which is the top level/root block i.e. method body should be - * converted into a implicit return. However, if the last statement is a if-else it has sub-blocks/child blocks and the last statement of each leaf block in it - * will have to be converted to a implicit return, unless it is already a implicit return. - * Some sub-blocks are exempt from their last statements being converted to returns. Examples are blocks that are arguments to functions like string interpolation. - * - * isMethodBody => The statement set is the top level block in the method. i.e. the root block - * canConsiderAsLeaf => The statement set can be considered a leaf block. This is set to false by the caller when it is a statement - * set as a part of an expression. Eg. argument in string interpolation. We do not want to construct return nodes out of - * string interpolation arguments. These are exempt blocks for implicit returns. - * blockChildHash => Hash of a block id to any child. Absence of a block in this after all its statements have been processed implies - * that the block is a leaf - * blockIdCounter => A simple counter used to assign an unique id to each block. - */ - protected def astForStatements( - ctx: StatementsContext, - isMethodBody: Boolean = false, - canConsiderAsLeaf: Boolean = true - ): Seq[Ast] = { - - def astsForStmtCtx(stCtx: StatementContext, stmtCount: Int, stmtCounter: Int): Seq[Ast] = { - if (isMethodBody) processingLastMethodStatement.lazySet(stmtCounter == stmtCount) - val stAsts = astForStatement(stCtx) - if (stAsts.nonEmpty && canConsiderAsLeaf && processingLastMethodStatement.get) { - blockChildHash.get(currentBlockId.get) match { - case Some(_) => - // this is a non-leaf block - stAsts - case None => - // this is a leaf block - processingLastMethodStatement.lazySet(!(isMethodBody && stmtCounter == stmtCount)) - Seq(lastStmtAsReturnAst(stCtx, stAsts.head, Option(text(stCtx)))) - } - } else { - stAsts - } - } - - Option(ctx) - .map { ctx => - val stmtCount = ctx.statement.size - val parentBlockId = currentBlockId.get - if (canConsiderAsLeaf) blockChildHash.update(parentBlockId, currentBlockId.get) - currentBlockId.lazySet(blockIdCounter.addAndGet(1)) - - val stmtAsts = Option(ctx) - .map(_.statement) - .map(_.asScala) - .getOrElse(Seq.empty) - .zipWithIndex - .flatMap { case (stmtCtx, idx) => astsForStmtCtx(stmtCtx, stmtCount, idx + 1) } - .toSeq - currentBlockId.lazySet(parentBlockId) - stmtAsts - } - .getOrElse(Seq.empty) - } - - // TODO: return Ast instead of Seq[Ast]. - private def astForStatement(ctx: StatementContext): Seq[Ast] = scanStmtForHereDoc(ctx match { - case ctx: AliasStatementContext => Seq(astForAliasStatement(ctx)) - case ctx: UndefStatementContext => Seq(astForUndefStatement(ctx)) - case ctx: BeginStatementContext => Seq(astForBeginStatement(ctx)) - case ctx: EndStatementContext => Seq(astForEndStatement(ctx)) - case ctx: ModifierStatementContext => Seq(astForModifierStatement(ctx)) - case ctx: ExpressionOrCommandStatementContext => astForExpressionOrCommand(ctx.expressionOrCommand()) - }) - - // TODO: return Ast instead of Seq[Ast] - protected def astForExpressionOrCommand(ctx: ExpressionOrCommandContext): Seq[Ast] = ctx match { - case ctx: InvocationExpressionOrCommandContext => astForInvocationExpressionOrCommandContext(ctx) - case ctx: NotExpressionOrCommandContext => Seq(astForNotKeywordExpressionOrCommand(ctx)) - case ctx: OrAndExpressionOrCommandContext => Seq(astForOrAndExpressionOrCommand(ctx)) - case ctx: ExpressionExpressionOrCommandContext => astForExpressionContext(ctx.expression()) - case _ => - logger.error(s"astForExpressionOrCommand() $relativeFilename, ${text(ctx)} All contexts mismatched.") - Seq(Ast()) - } - - private def astForNotKeywordExpressionOrCommand(ctx: NotExpressionOrCommandContext): Ast = { - val exprOrCommandAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val call = callNode(ctx, code(ctx), Operators.not, Operators.not, DispatchTypes.STATIC_DISPATCH) - callAst(call, exprOrCommandAst) - } - - private def astForOrAndExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Ast = ctx.op.getType match { - case OR => astForOrExpressionOrCommand(ctx) - case AND => astForAndExpressionOrCommand(ctx) - } - - private def astForOrExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Ast = { - val argsAst = ctx.expressionOrCommand().asScala.flatMap(astForExpressionOrCommand) - val call = callNode(ctx, code(ctx), Operators.or, Operators.or, DispatchTypes.STATIC_DISPATCH) - callAst(call, argsAst.toList) - } - - private def astForAndExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Ast = { - val argsAst = ctx.expressionOrCommand().asScala.flatMap(astForExpressionOrCommand) - val call = callNode(ctx, code(ctx), Operators.and, Operators.and, DispatchTypes.STATIC_DISPATCH) - callAst(call, argsAst.toList) - } - - private def astForSuperCommand(ctx: SuperCommandContext): Ast = - astForSuperCall(ctx, astForArguments(ctx.argumentsWithoutParentheses().arguments())) - - private def astForYieldCommand(ctx: YieldCommandContext): Ast = - astForYieldCall(ctx, Option(ctx.argumentsWithoutParentheses().arguments())) - - private def astForSimpleMethodCommand(ctx: SimpleMethodCommandContext): Seq[Ast] = { - val methodIdentifierAsts = astForMethodIdentifierContext(ctx.methodIdentifier(), code(ctx)) - methodIdentifierAsts.headOption.foreach(methodNameAsIdentifierStack.push) - val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - - /* get args without the method def in it */ - val argAstsWithoutMethods = argsAsts.filterNot(_.root.exists(_.isInstanceOf[NewMethod])) - - /* isolate methods from the original args and create identifier ASTs from it */ - val methodDefAsts = argsAsts.filter(_.root.exists(_.isInstanceOf[NewMethod])) - val methodToIdentifierAsts = methodDefAsts.flatMap { - _.nodes.collectFirst { case methodNode: NewMethod => - Ast( - createIdentifierWithScope( - methodNode.name, - methodNode.name, - Defines.Any, - Seq.empty, - methodNode.lineNumber, - methodNode.columnNumber, - definitelyIdentifier = true - ) - ) - } - } - - /* TODO: we add the isolated method defs later on to the parent instead */ - methodDefInArgument.addAll(methodDefAsts) - - val callNodes = methodIdentifierAsts.head.nodes.collect { case x: NewCall => x } - if (callNodes.size == 1) { - val callNode = callNodes.head - if (callNode.name == "require" || callNode.name == "load") { - resolveRequireOrLoadPath(argsAsts, callNode) - } else if (callNode.name == "require_relative") { - resolveRelativePath(filename, argsAsts, callNode) - } else if (prefixMethods.contains(callNode.name)) { - /* we remove the method definition AST from argument and add its corresponding identifier form */ - Seq(callAst(callNode, argAstsWithoutMethods ++ methodToIdentifierAsts)) - } else { - Seq(callAst(callNode, argsAsts)) - } - } else { - argsAsts - } - } - - private def astForMemberAccessCommand(ctx: MemberAccessCommandContext): Seq[Ast] = { - astForMethodNameContext(ctx.methodName).headOption - .flatMap(_.root) - .collectFirst { case x: NewCall => resolveAlias(x.name) } - .map(methodName => - callNode( - ctx, - code(ctx), - methodName, - DynamicCallUnknownFullName, - DispatchTypes.STATIC_DISPATCH, - None, - Option(Defines.Any) - ) - ) match - case Some(newCall) => - val primaryAst = astForPrimaryContext(ctx.primary()) - val argsAst = astForArguments(ctx.argumentsWithoutParentheses().arguments) - primaryAst.headOption - .flatMap(_.root) - .collectFirst { case x: NewMethod => x } - .map { methodNode => - val methodRefNode = methodToMethodRef(ctx, methodNode) - blockMethods.addOne(primaryAst.head) - Seq(callAst(newCall, Seq(Ast(methodRefNode)) ++ argsAst)) - } - .getOrElse(Seq(callAst(newCall, argsAst, primaryAst.headOption))) - case None => Seq.empty - } - - private def methodToMethodRef(ctx: ParserRuleContext, methodNode: NewMethod): NewMethodRef = - methodRefNode(ctx, s"def ${methodNode.name}(...)", methodNode.fullName, Defines.Any) - - protected def astForCommand(ctx: CommandContext): Seq[Ast] = ctx match { - case ctx: YieldCommandContext => Seq(astForYieldCommand(ctx)) - case ctx: SuperCommandContext => Seq(astForSuperCommand(ctx)) - case ctx: SimpleMethodCommandContext => astForSimpleMethodCommand(ctx) - case ctx: MemberAccessCommandContext => astForMemberAccessCommand(ctx) - } - - private def resolveRequireOrLoadPath(argsAst: Seq[Ast], callNode: NewCall): Seq[Ast] = { - val importedNode = argsAst.headOption.map(_.nodes.collect { case x: NewLiteral => x }).getOrElse(Seq.empty) - if (importedNode.size == 1) { - val node = importedNode.head - val pathValue = stripQuotes(node.code) - val result = pathValue match { - case path if File(path).exists => - path - case path if File(s"$path.rb").exists => - s"$path.rb" - case _ => - pathValue - } - packageStack.append(result) - val importNode = createImportNodeAndLink(result, pathValue, Some(callNode), diffGraph) - Seq(callAst(callNode, argsAst), Ast(importNode)) - } else { - Seq(callAst(callNode, argsAst)) - } - } - - protected def resolveRelativePath(currentFile: String, argsAst: Seq[Ast], callNode: NewCall): Seq[Ast] = { - val importedNode = argsAst.head.nodes.collect { case x: NewLiteral => x } - if (importedNode.size == 1) { - val node = importedNode.head - val pathValue = stripQuotes(node.code) - val updatedPath = if (pathValue.endsWith(".rb")) pathValue else s"$pathValue.rb" - - val currentDirectory = File(currentFile).parent - val file = File(currentDirectory, updatedPath) - packageStack.append(file.pathAsString) - val importNode = createImportNodeAndLink(updatedPath, pathValue, Some(callNode), diffGraph) - Seq(callAst(callNode, argsAst), Ast(importNode)) - } else { - Seq(callAst(callNode, argsAst)) - } - } - - protected def astForBlock(ctx: BlockContext, blockMethodName: Option[String] = None): Ast = ctx match - case ctx: DoBlockBlockContext => astForDoBlock(ctx.doBlock(), blockMethodName) - case ctx: BraceBlockBlockContext => astForBraceBlock(ctx.braceBlock(), blockMethodName) - - private def astForBlockHelper( - ctx: ParserRuleContext, - blockParamCtx: Option[BlockParameterContext], - compoundStmtCtx: CompoundStatementContext, - blockMethodName: Option[String] = None - ) = { - blockMethodName match { - case Some(blockMethodName) => - astForBlockFunction( - compoundStmtCtx.statements(), - blockParamCtx, - blockMethodName, - line(compoundStmtCtx).head, - lineEnd(compoundStmtCtx).head, - column(compoundStmtCtx).head, - columnEnd(compoundStmtCtx).head - ).head - case None => - val blockNode_ = blockNode(ctx, code(ctx), Defines.Any) - val blockBodyAst = astForCompoundStatement(compoundStmtCtx) - val blockParamAst = blockParamCtx.flatMap(astForBlockParameterContext) - blockAst(blockNode_, blockBodyAst.toList ++ blockParamAst) - } - } - - protected def astForDoBlock(ctx: DoBlockContext, blockMethodName: Option[String] = None): Ast = { - astForBlockHelper(ctx, Option(ctx.blockParameter), ctx.bodyStatement().compoundStatement(), blockMethodName) - } - - private def astForBraceBlock(ctx: BraceBlockContext, blockMethodName: Option[String] = None): Ast = { - astForBlockHelper(ctx, Option(ctx.blockParameter), ctx.bodyStatement().compoundStatement(), blockMethodName) - } - - // TODO: This class shouldn't be required and will eventually be phased out. - protected implicit class BlockContextExt(val ctx: BlockContext) { - def compoundStatement: CompoundStatementContext = { - fold(_.bodyStatement.compoundStatement, _.bodyStatement.compoundStatement) - } - - def blockParameter: Option[BlockParameterContext] = { - fold(ctx => Option(ctx.blockParameter()), ctx => Option(ctx.blockParameter())) - } - - private def fold[A](f: DoBlockContext => A, g: BraceBlockContext => A): A = ctx match { - case ctx: DoBlockBlockContext => f(ctx.doBlock()) - case ctx: BraceBlockBlockContext => g(ctx.braceBlock()) - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForTypesCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForTypesCreator.scala deleted file mode 100644 index 5ff0df57e826..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/AstForTypesCreator.scala +++ /dev/null @@ -1,270 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser.* -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.x2cpg.utils.* -import io.joern.x2cpg.{Ast, ValidationMode, Defines as XDefines} -import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{ModifierTypes, NodeTypes, Operators, nodes} -import org.antlr.v4.runtime.ParserRuleContext - -import scala.collection.mutable - -trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - // Maps field references of known types - protected val fieldReferences: mutable.HashMap[String, Set[ParserRuleContext]] = mutable.HashMap.empty - protected val typeDeclNameToTypeDecl: mutable.HashMap[String, NewTypeDecl] = mutable.HashMap.empty - - def astForClassDeclaration(ctx: ClassDefinitionPrimaryContext): Seq[Ast] = { - val className = ctx.className.getOrElse(Defines.Any) - if (className != Defines.Any) { - classStack.push(className) - val fullName = classStack.reverse.mkString(pathSep) - - val bodyAst = astForClassBody(ctx.classDefinition().bodyStatement()).map { ast => - ast.root.foreach { - case node: NewMethod => - node - .astParentType(NodeTypes.TYPE_DECL) - .astParentFullName(fullName) - case _ => - } - ast - } - - if (classStack.nonEmpty) { - classStack.pop() - } - - val typeDecl = typeDeclNode(ctx, className, fullName, relativeFilename, code(ctx).takeWhile(_ != '\n')) - - // create constructor if not explicitly defined - val hasConstructor = - bodyAst.flatMap(_.root).collect { case x: NewMethod => x.name }.contains(XDefines.ConstructorMethodName) - val defaultConstructor = - if (!hasConstructor) - createDefaultConstructor(ctx, typeDecl, bodyAst.flatMap(_.nodes).collect { case x: NewMember => x }) - else Seq.empty - - typeDeclNameToTypeDecl.put(className, typeDecl) - Seq(Ast(typeDecl).withChildren(defaultConstructor ++ bodyAst)) - } else { - Seq.empty - } - } - - /** If no constructor is explicitly defined, will create a default one. - */ - private def createDefaultConstructor( - ctx: ClassDefinitionPrimaryContext, - typeDecl: NewTypeDecl, - fields: Seq[NewMember] - ): Seq[Ast] = { - val name = XDefines.ConstructorMethodName - val code = Seq(typeDecl.name, name).mkString(pathSep) - val fullName = Seq(typeDecl.fullName, name).mkString(pathSep) - - val constructorNode = - methodNode(ctx, name, code, fullName, None, relativeFilename, Option(typeDecl.label), Option(typeDecl.fullName)) - val thisParam = createMethodParameterIn("this", None, None, typeDecl.fullName) - val params = - thisParam +: fields.map(m => createMethodParameterIn(m.name, None, None, m.typeFullName)) - val assignments = fields.map { m => - val thisNode = createThisIdentifier(ctx) - val lhs = astForFieldAccess(ctx, thisNode) - val paramIdentifier = identifierNode(ctx, m.name, m.name, m.typeFullName) - val refParam = params.find(_.name == m.name).get - astForAssignment(lhs.root.get, paramIdentifier) - .withRefEdge(thisNode, thisParam) - .withRefEdge(paramIdentifier, refParam) - }.toList - val body = blockAst(blockNode(ctx), assignments) - val methodReturn = methodReturnNode(ctx, typeDecl.fullName) - - Seq(methodAst(constructorNode, params.map(Ast.apply(_)), body, methodReturn)) - } - - def astForClassExpression(ctx: ClassDefinitionPrimaryContext): Seq[Ast] = { - // TODO test for this is pending due to lack of understanding to generate an example - val astExprOfCommand = astForExpressionOrCommand(ctx.classDefinition().expressionOrCommand()) - val astBodyStatement = astForBodyStatementContext(ctx.classDefinition().bodyStatement()) - val blockNode = NewBlock() - .code(text(ctx)) - val bodyBlockAst = blockAst(blockNode, astBodyStatement.toList) - astExprOfCommand ++ Seq(bodyBlockAst) - } - - def astForModuleDefinitionPrimaryContext(ctx: ModuleDefinitionPrimaryContext): Seq[Ast] = { - val className = ctx.moduleDefinition().classOrModuleReference().classOrModuleName - - if (className != Defines.Any) { - classStack.push(className) - - val fullName = classStack.reverse.mkString(pathSep) - val namespaceBlock = NewNamespaceBlock() - .name(className) - .fullName(fullName) - .filename(relativeFilename) - - val moduleBodyAst = astInFakeMethod(className, fullName, relativeFilename, ctx) - classStack.pop() - Seq(Ast(namespaceBlock).withChildren(moduleBodyAst)) - } else { - Seq.empty - } - - } - - private def astInFakeMethod( - name: String, - fullName: String, - path: String, - ctx: ModuleDefinitionPrimaryContext - ): Seq[Ast] = { - - val fakeGlobalTypeDecl = NewTypeDecl() - .name(name) - .fullName(fullName) - - val bodyAst = astForClassBody(ctx.moduleDefinition().bodyStatement()) - Seq(Ast(fakeGlobalTypeDecl).withChildren(bodyAst)) - } - - private def getClassNameScopedConstantReferenceContext(ctx: ScopedConstantReferenceContext): String = { - val classTerminalNode = ctx.CONSTANT_IDENTIFIER() - - if (ctx.primary() != null) { - val primaryAst = astForPrimaryContext(ctx.primary()) - val moduleNameNode = primaryAst.head.nodes - .filter(node => node.isInstanceOf[NewIdentifier]) - .head - .asInstanceOf[NewIdentifier] - val moduleName = moduleNameNode.name - moduleName + "." + classTerminalNode.getText - } else { - classTerminalNode.getText - } - } - - def membersFromStatementAsts(ast: Ast): Seq[Ast] = - ast.nodes - .collect { case i: NewIdentifier if i.name.startsWith("@") || i.name.isAllUpperCase => i } - .map { i => - val code = ast.root.collect { case c: NewCall => c.code }.getOrElse(i.name) - val modifierType = i.name match - case x if x.startsWith("@@") => ModifierTypes.STATIC - case x if x.isAllUpperCase => ModifierTypes.FINAL - case _ => ModifierTypes.VIRTUAL - val modifierAst = Ast(NewModifier().modifierType(modifierType)) - Ast( - NewMember() - .code(code) - .name(i.name.replaceAll("@", "")) - .typeFullName(i.typeFullName) - .lineNumber(i.lineNumber) - .columnNumber(i.columnNumber) - ).withChild(modifierAst) - } - .toSeq - - /** Handles body statements differently from [[astForBodyStatementContext]] by noting that method definitions should - * be on the root level and assignments where the LHS starts with @@ should be treated as fields. - */ - private def astForClassBody(ctx: BodyStatementContext): Seq[Ast] = { - val rootStatements = - Option(ctx).map(_.compoundStatement()).map(_.statements()).map(astForStatements(_)).getOrElse(Seq()) - retrieveAndGenerateClassChildren(ctx, rootStatements) - } - - /** As class bodies are not treated much differently to other procedure bodies, we need to retrieve certain components - * that would result in the creation of interprocedural constructs. - * - * TODO: This is pretty hacky and the parser could benefit from more specific tokens - */ - private def retrieveAndGenerateClassChildren(classCtx: BodyStatementContext, rootStatements: Seq[Ast]): Seq[Ast] = { - val (memberLikeStmts, blockStmts) = rootStatements - .flatMap { ast => - ast.root match - case Some(_: NewMethod) => Seq(ast) - case Some(x: NewCall) if x.name == Operators.assignment => Seq(ast) ++ membersFromStatementAsts(ast) - case _ => Seq(ast) - } - .partition(_.root match - case Some(_: NewMethod) => true - case Some(_: NewMember) => true - case _ => false - ) - - val methodStmts = memberLikeStmts.filter(_.root.exists(_.isInstanceOf[NewMethod])) - val memberNodes = memberLikeStmts.flatMap(_.root).collect { case m: NewMember => m } - - val uniqueMemberReferences = - (memberNodes ++ fieldReferences.getOrElse(classStack.top, Set.empty).groupBy(_.getText).map { case (code, ctxs) => - NewMember() - .name(code.replaceAll("@", "")) - .code(code) - .typeFullName(Defines.Any) - }).toList.distinctBy(_.name).map { m => - val modifierType = m.name match - case x if x.startsWith("@@") => ModifierTypes.STATIC - case _ => ModifierTypes.VIRTUAL - val modifierAst = Ast(NewModifier().modifierType(modifierType)) - Ast(m).withChild(modifierAst) - } - - // Create class initialization method to host all field initializers - val classInitMethodAst = if (blockStmts.nonEmpty) { - val classInitFullName = (classStack.reverse :+ XDefines.StaticInitMethodName).mkString(pathSep) - val classInitMethod = methodNode( - classCtx, - XDefines.StaticInitMethodName, - XDefines.StaticInitMethodName, - classInitFullName, - None, - relativeFilename, - Option(NodeTypes.TYPE_DECL), - Option(classStack.reverse.mkString(pathSep)) - ) - val classInitBody = blockAst(blockNode(classCtx), blockStmts.toList) - Seq(methodAst(classInitMethod, Seq.empty, classInitBody, methodReturnNode(classCtx, Defines.Any))) - } else { - Seq.empty - } - - classInitMethodAst ++ uniqueMemberReferences ++ methodStmts - } - - implicit class ClassDefinitionPrimaryContextExt(val ctx: ClassDefinitionPrimaryContext) { - - def hasClassDefinition: Boolean = Option(ctx.classDefinition()).isDefined - - def className: Option[String] = - Option(ctx.classDefinition().classOrModuleReference()) match { - case Some(classOrModuleReferenceCtx) => - Option(classOrModuleReferenceCtx) - .map(_.classOrModuleName) - case None => - // TODO the below is just to avoid crashes. This needs to be implemented properly - None - } - } - - implicit class ClassOrModuleReferenceContextExt(val ctx: ClassOrModuleReferenceContext) { - - def hasScopedConstantReference: Boolean = Option(ctx.scopedConstantReference()).isDefined - - def classOrModuleName: String = - Option(ctx) match { - case Some(ct) => - if (ct.hasScopedConstantReference) - getClassNameScopedConstantReferenceContext(ct.scopedConstantReference()) - else - Option(ct.CONSTANT_IDENTIFIER()).map(_.getText) match { - case Some(className) => className - case None => Defines.Any - } - case None => Defines.Any - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/RubyScope.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/RubyScope.scala deleted file mode 100644 index 773c656bb44b..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/astcreation/RubyScope.scala +++ /dev/null @@ -1,110 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.astcreation - -import io.joern.x2cpg.datastructures.Scope -import io.shiftleft.codepropertygraph.generated.EdgeTypes -import io.shiftleft.codepropertygraph.generated.nodes.{DeclarationNew, NewIdentifier, NewLocal, NewNode} -import overflowdb.BatchedUpdate - -import scala.collection.mutable - -/** Extends the Scope class to help scope variables and create locals. - * - * TODO: Extend this to similarly link parameter nodes (especially `this` node) for consistency. - */ -class RubyScope extends Scope[String, NewIdentifier, NewNode] { - - private type VarMap = Map[String, VarGroup] - private type ScopeNodeType = NewNode - - /** Groups a local node with its referencing identifiers. - */ - private case class VarGroup(local: NewLocal, ids: List[NewIdentifier]) - - /** Links a scope to its variable groupings. - */ - private val scopeToVarMap = mutable.HashMap.empty[ScopeNodeType, VarMap] - - override def addToScope(identifier: String, variable: NewIdentifier): NewNode = { - val scopeNode = super.addToScope(identifier, variable) - stack.headOption.foreach(head => scopeToVarMap.appendIdentifierToVarGroup(head.scopeNode, variable)) - scopeNode - } - - override def popScope(): Option[NewNode] = { - stack.headOption.map(_.scopeNode).foreach(scopeToVarMap.remove) - super.popScope() - } - - /** Will generate local nodes for this scope's variables, excluding those that reference parameters. - * @param paramNames - * the names of parameters. - */ - def createAndLinkLocalNodes( - diffGraph: BatchedUpdate.DiffGraphBuilder, - paramNames: Set[String] = Set.empty - ): List[DeclarationNew] = stack.headOption match - case Some(top) => scopeToVarMap.buildVariableGroupings(top.scopeNode, paramNames ++ Set("this"), diffGraph) - case None => List.empty[DeclarationNew] - - /** @param identifier - * the identifier to count - * @return - * the number of times the given identifier occurs in the immediate scope. - */ - def numVariableReferences(identifier: String): Int = { - stack.map(_.scopeNode).flatMap(scopeToVarMap.get).flatMap(_.get(identifier)).map(_.ids.size).headOption.getOrElse(0) - } - - private implicit class IdentifierExt(node: NewIdentifier) { - - /** Creates a new VarGroup and corresponding NewLocal for the given identifier. - */ - def toNewVarGroup: VarGroup = { - val newLocal = NewLocal() - .name(node.name) - .code(node.name) - .lineNumber(node.lineNumber) - .columnNumber(node.columnNumber) - .typeFullName(node.typeFullName) - VarGroup(newLocal, List(node)) - } - - } - - private implicit class ScopeExt(scopeMap: mutable.Map[ScopeNodeType, VarMap]) { - - /** Registers the identifier to its corresponding variable grouping in the given scope. - */ - def appendIdentifierToVarGroup(key: ScopeNodeType, identifier: NewIdentifier): Unit = - scopeMap.updateWith(key) { - case Some(varMap: VarMap) => - Some(varMap.updatedWith(identifier.name) { - case Some(varGroup: VarGroup) => Some(varGroup.copy(ids = varGroup.ids :+ identifier)) - case None => Some(identifier.toNewVarGroup) - }) - case None => - Some(Map(identifier.name -> identifier.toNewVarGroup)) - } - - /** Will persist the variable groupings that do not represent parameter nodes and link them with REF edges. - * @return - * the list of persisted local nodes. - */ - def buildVariableGroupings( - key: ScopeNodeType, - paramNames: Set[String], - diffGraph: BatchedUpdate.DiffGraphBuilder - ): List[DeclarationNew] = - scopeMap.get(key) match - case Some(varMap) => - varMap.values - .filterNot { case VarGroup(local, _) => paramNames.contains(local.name) } - .map { case VarGroup(local, ids) => - ids.foreach(id => diffGraph.addEdge(id, local, EdgeTypes.REF)) - local - } - .toList - case None => List.empty[DeclarationNew] - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyLexerBase.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyLexerBase.scala deleted file mode 100644 index f48315811f43..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyLexerBase.scala +++ /dev/null @@ -1,49 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyLexer.* -import org.antlr.v4.runtime.Recognizer.EOF -import org.antlr.v4.runtime.{CharStream, Lexer, Token} - -/** Aggregates auxiliary features to DeprecatedRubyLexer in a single place. */ -abstract class DeprecatedRubyLexerBase(input: CharStream) - extends Lexer(input) - with RegexLiteralHandling - with InterpolationHandling - with QuotedLiteralHandling - with HereDocHandling { - - /** The previously (non-WS) emitted token (in DEFAULT_CHANNEL.) */ - protected var previousNonWsToken: Option[Token] = None - - /** The previously emitted token (in DEFAULT_CHANNEL.) */ - protected var previousToken: Option[Token] = None - - // Same original behaviour, just updating `previous{NonWs}Token`. - override def nextToken: Token = { - val token: Token = super.nextToken - if (token.getChannel == Token.DEFAULT_CHANNEL && token.getType != WS) { - previousNonWsToken = Some(token) - } - previousToken = Some(token) - token - } - - def previousNonWsTokenTypeOrEOF(): Int = { - previousNonWsToken.map(_.getType).getOrElse(EOF) - } - - def previousTokenTypeOrEOF(): Int = { - previousToken.map(_.getType).getOrElse(EOF) - } - - def isNumericTokenType(tokenType: Int): Boolean = { - val numericTokenTypes = Set( - DECIMAL_INTEGER_LITERAL, - OCTAL_INTEGER_LITERAL, - HEXADECIMAL_INTEGER_LITERAL, - FLOAT_LITERAL_WITHOUT_EXPONENT, - FLOAT_LITERAL_WITH_EXPONENT - ) - numericTokenTypes.contains(tokenType) - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyLexerPostProcessor.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyLexerPostProcessor.scala deleted file mode 100644 index 0eba4c655137..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/DeprecatedRubyLexerPostProcessor.scala +++ /dev/null @@ -1,74 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyLexer.* -import org.antlr.v4.runtime.Recognizer.EOF -import org.antlr.v4.runtime.misc.Pair -import org.antlr.v4.runtime.{CommonToken, ListTokenSource, Token, TokenSource} - -import scala.:: -import scala.jdk.CollectionConverters.* - -/** Simplifies the token stream obtained from `DeprecatedRubyLexer`. - */ -object DeprecatedRubyLexerPostProcessor { - - def apply(tokenSource: TokenSource): ListTokenSource = { - var tokens = tokenSource.toSeq - - tokens = tokens.mergeConsecutive(NON_EXPANDED_LITERAL_CHARACTER, NON_EXPANDED_LITERAL_CHARACTER_SEQUENCE) - tokens = tokens.mergeConsecutive(EXPANDED_LITERAL_CHARACTER, EXPANDED_LITERAL_CHARACTER_SEQUENCE) - tokens = tokens.filterNot(_.is(WS)) - - new ListTokenSource(tokens.asJava) - } -} - -private implicit class TokenSourceExt(val tokenSource: TokenSource) { - - def toSeq: Seq[Token] = Seq.unfold(tokenSource) { tkSrc => - tkSrc.nextToken() match - case tk if tk.is(EOF) => None - case tk => Some((tk, tkSrc)) - } -} - -private implicit class SeqExt[A](val elems: Seq[A]) { - - /** An order-preserving `groupBy` implemented on top of `Seq`. Each sub-sequence ("chain") contains 1+ elements. If a - * chain contains 2+ elements, then all its elements satisfy `p`. Flattening returns the original sequence. - */ - def chains(p: A => Boolean): Seq[Seq[A]] = elems.foldRight(Nil: Seq[Seq[A]]) { (h, t) => - t match - case chain :: chains if chain.exists(p) && p(h) => (h +: chain) +: chains - case _ => Seq(h) +: t - } - - /** Collapses, according to a merging operation `m`, all chains that verify `p`. - */ - def mergeChains(p: A => Boolean, m: Seq[A] => A): Seq[A] = { - elems.chains(p).flatMap(chain => if (chain.exists(p)) Seq(m(chain)) else chain) - } - -} - -private implicit class TokenSeqExt(val tokens: Seq[Token]) { - - def mergeAs(tokenType: Int): Token = { - val startIndex = tokens.head.getStartIndex - val stopIndex = tokens.last.getStopIndex - val tokenSource = tokens.head.getTokenSource - val inputStream = tokens.head.getInputStream - val channel = tokens.head.getChannel - new CommonToken(new Pair(tokenSource, inputStream), tokenType, channel, startIndex, stopIndex) - } - - def mergeConsecutive(oldTokenType: Int, newTokenType: Int): Seq[Token] = { - tokens.mergeChains(_.is(oldTokenType), _.mergeAs(newTokenType)) - } -} - -private implicit class TokenExt(val token: Token) { - - def is(tokenType: Int): Boolean = token.getType == tokenType - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/HereDocHandling.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/HereDocHandling.scala deleted file mode 100644 index c4d6f25991ca..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/HereDocHandling.scala +++ /dev/null @@ -1,35 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -import better.files.EOF - -trait HereDocHandling { this: DeprecatedRubyLexerBase => - - /** @see - * Stack - * Overflow - */ - def heredocEndAhead(partialHeredoc: String): Boolean = - if (this.getCharPositionInLine != 0) { - // If the lexer is not at the start of a line, no end-delimiter can be possible - false - } else { - // Get the delimiter - HereDocHandling.getHereDocDelimiter(partialHeredoc) match - case Some(delimiter) if !delimiter.zipWithIndex.exists { case (c, idx) => this._input.LA(idx + 1) != c } => - // If we get to this point, we know there is an end delimiter ahead in the char stream, make - // sure it is followed by a white space (or the EOF). If we don't do this, then "FOOS" would also - // be considered the end for the delimiter "FOO" - val charAfterDelimiter = this._input.LA(delimiter.length + 1) - charAfterDelimiter == EOF || Character.isWhitespace(charAfterDelimiter) - case _ => false - } - -} - -object HereDocHandling { - - def getHereDocDelimiter(hereDoc: String): Option[String] = - hereDoc.split("\r?\n|\r").headOption.map(_.replaceAll("^<<[~-]\\s*", "")) - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/InterpolationHandling.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/InterpolationHandling.scala deleted file mode 100644 index 2ea50b1bd5f9..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/InterpolationHandling.scala +++ /dev/null @@ -1,21 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -import scala.collection.mutable - -trait InterpolationHandling { this: DeprecatedRubyLexerBase => - - private val interpolationEndTokenType = mutable.Stack[Int]() - - def pushInterpolationEndTokenType(endTokenType: Int): Unit = { - interpolationEndTokenType.push(endTokenType) - } - - def popInterpolationEndTokenType(): Int = { - interpolationEndTokenType.pop() - } - - def isEndOfInterpolation: Boolean = { - interpolationEndTokenType.nonEmpty - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/QuotedLiteralHandling.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/QuotedLiteralHandling.scala deleted file mode 100644 index 50daf48988f5..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/QuotedLiteralHandling.scala +++ /dev/null @@ -1,45 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -import scala.collection.mutable - -trait QuotedLiteralHandling { this: DeprecatedRubyLexerBase => - - private val delimiters = mutable.Stack[Int]() - private val endTokenTypes = mutable.Stack[Int]() - - private def closingDelimiterFor(char: Int): Int = char match - case '(' => ')' - case '[' => ']' - case '{' => '}' - case '<' => '>' - case c => c - - private def currentOpeningDelimiter: Int = delimiters.top - - private def currentClosingDelimiter: Int = closingDelimiterFor(currentOpeningDelimiter) - - private def isOpeningDelimiter(char: Int): Boolean = char == currentOpeningDelimiter - - private def isClosingDelimiter(char: Int): Boolean = char == currentClosingDelimiter - - def pushQuotedDelimiter(char: Int): Unit = delimiters.push(char) - - def popQuotedDelimiter(): Unit = delimiters.pop() - - def pushQuotedEndTokenType(endTokenType: Int): Unit = endTokenTypes.push(endTokenType) - - def popQuotedEndTokenType(): Int = endTokenTypes.pop() - - def consumeQuotedCharAndMaybePopMode(char: Int): Unit = { - if (isClosingDelimiter(char)) { - popQuotedDelimiter() - - if (delimiters.isEmpty) { - setType(endTokenTypes.pop()) - popMode() - } - } else if (isOpeningDelimiter(char)) { - pushQuotedDelimiter(char) - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/RegexLiteralHandling.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/RegexLiteralHandling.scala deleted file mode 100644 index 2e84700c59ba..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/parser/RegexLiteralHandling.scala +++ /dev/null @@ -1,78 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyLexer.* -import org.antlr.v4.runtime.Recognizer.EOF - -trait RegexLiteralHandling { this: DeprecatedRubyLexerBase => - - /* When encountering '/', we need to decide whether this is a binary operator (e.g. `x / y`) or - * a regular expression delimiter (e.g. `/(eu|us)/`) occurrence. Our approach is to look at the - * previously emitted token and decide accordingly. - */ - private val regexTogglingTokens: Set[Int] = Set( - // When '/' occurs after an opening parenthesis, brace or bracket. - LPAREN, - LCURLY, - LBRACK, - // When '/' occurs after a NL. - NL, - // When '/' occurs after a ','. - COMMA, - // When '/' occurs after a ':'. - COLON, - // When '/' occurs after 'when'. - WHEN, - // When '/' occurs after 'unless'. - UNLESS, - // When '/' occurs after an operator. - EMARK, - EMARKEQ, - EMARKTILDE, - AMP, - AMP2, - AMPDOT, - BAR, - BAR2, - EQ, - EQ2, - EQ3, - CARET, - LTEQGT, - EQTILDE, - GT, - GTEQ, - LT, - LTEQ, - LT2, - GT2, - PLUS, - MINUS, - STAR, - STAR2, - SLASH, - PERCENT, - TILDE, - PLUSAT, - MINUSAT, - ASSIGNMENT_OPERATOR - ) - - /** To be invoked when encountering `/`, deciding if it should emit a `REGULAR_EXPRESSION_START` token. */ - protected def isStartOfRegexLiteral: Boolean = { - val isFirstTokenInTheStream = previousNonWsToken.isEmpty - val isRegexTogglingToken = regexTogglingTokens.contains(previousNonWsTokenTypeOrEOF()) - - isFirstTokenInTheStream || isRegexTogglingToken || isInCommandArgumentPosition - } - - /** Decides if the current `/` is being used as an argument to a command, based on the observation that such literals - * may not start with a WS. E.g. `puts /x/` is valid, but `puts / x/` is not. - */ - private def isInCommandArgumentPosition: Boolean = { - val previousNonWsIsIdentifier = previousNonWsTokenTypeOrEOF() == LOCAL_VARIABLE_IDENTIFIER - val previousIsWs = previousTokenTypeOrEOF() == WS - val nextCharIsWs = _input.LA(1) == ' ' - previousNonWsIsIdentifier && previousIsWs && !nextCharIsWs - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/AstCreationPass.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/AstCreationPass.scala deleted file mode 100644 index 1a93f4b5e9e3..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/AstCreationPass.scala +++ /dev/null @@ -1,44 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes - -import io.joern.rubysrc2cpg.Config -import io.joern.rubysrc2cpg.deprecated.astcreation.{AstCreator, ResourceManagedParser} -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyParser -import io.joern.rubysrc2cpg.deprecated.utils.{PackageContext, PackageTable} -import io.joern.x2cpg.SourceFiles -import io.joern.x2cpg.datastructures.Global -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language.* -import org.slf4j.LoggerFactory -import overflowdb.BatchedUpdate - -import scala.jdk.CollectionConverters.EnumerationHasAsScala - -class AstCreationPass( - cpg: Cpg, - parsedFiles: List[(String, DeprecatedRubyParser.ProgramContext)], - packageTable: PackageTable, - config: Config -) extends ForkJoinParallelCpgPass[(String, DeprecatedRubyParser.ProgramContext)](cpg) { - - private val logger = LoggerFactory.getLogger(this.getClass) - - override def generateParts(): Array[(String, DeprecatedRubyParser.ProgramContext)] = parsedFiles.toArray - - override def runOnPart( - diffGraph: DiffGraphBuilder, - fileNameAndContext: (String, DeprecatedRubyParser.ProgramContext) - ): Unit = { - val (fileName, context) = fileNameAndContext - try { - diffGraph.absorb( - new AstCreator(fileName, context, PackageContext(fileName, packageTable), cpg.metaData.root.headOption)( - config.schemaValidation - ).createAst() - ) - } catch { - case ex: Exception => - logger.error(s"Error while processing AST for file - $fileName - ", ex) - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/AstPackagePass.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/AstPackagePass.scala deleted file mode 100644 index 147a7b154daa..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/AstPackagePass.scala +++ /dev/null @@ -1,67 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes - -import better.files.File -import io.joern.rubysrc2cpg.deprecated.astcreation.{AstCreator, ResourceManagedParser} -import io.joern.rubysrc2cpg.deprecated.utils.{PackageContext, PackageTable} -import io.joern.x2cpg.ValidationMode -import io.joern.x2cpg.datastructures.Global -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.passes.ForkJoinParallelCpgPass -import org.slf4j.LoggerFactory - -import scala.util.{Failure, Success, Try} - -class AstPackagePass( - cpg: Cpg, - tempExtDir: String, - parser: ResourceManagedParser, - packageTable: PackageTable, - inputPath: String -)(implicit withSchemaValidation: ValidationMode) - extends ForkJoinParallelCpgPass[String](cpg) { - - private val logger = LoggerFactory.getLogger(getClass) - - override def generateParts(): Array[String] = - getRubyDependenciesFile(inputPath) ++ getRubyDependenciesFile(tempExtDir) - - override def runOnPart(diffGraph: DiffGraphBuilder, filePath: String): Unit = { - parser.parse(filePath) match - case Failure(exception) => logger.warn(s"Could not parse file: $filePath, skipping", exception); - case Success(programCtx) => - Try( - new AstCreator( - filePath, - programCtx, - PackageContext(resolveModuleNameFromPath(filePath), packageTable), - Option(inputPath) - ).createAst() - ) - - } - - private def getRubyDependenciesFile(inputPath: String): Array[String] = { - val currentDir = File(inputPath) - if (currentDir.exists) { - currentDir.listRecursively.filter(_.extension.exists(_ == ".rb")).map(_.path.toString).toArray - } else { - Array.empty - } - } - - private def resolveModuleNameFromPath(path: String): String = { - if (path.contains(tempExtDir)) { - val moduleNameRegex = Seq("gems", "([^", "]+)", "lib", ".*").mkString(java.io.File.separator).r - moduleNameRegex - .findFirstMatchIn(path) - .map(_.group(1)) - .getOrElse("") - .split(java.io.File.separator) - .last - .split("-") - .head - } else { - path - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/Defines.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/Defines.scala deleted file mode 100644 index 9beeaea5ea8a..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/Defines.scala +++ /dev/null @@ -1,39 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes - -import io.joern.rubysrc2cpg.deprecated.astcreation.GlobalTypes - -object Defines { - val Any: String = "ANY" - val Object: String = "Object" - - val NilClass: String = "NilClass" - val TrueClass: String = "TrueClass" - val FalseClass: String = "FalseClass" - - val Numeric: String = "Numeric" - val Integer: String = "Integer" - val Float: String = "Float" - - val String: String = "String" - val Symbol: String = "Symbol" - - val Array: String = "Array" - val Hash: String = "Hash" - - val Encoding: String = "Encoding" - val Regexp: String = "Regexp" - - // TODO: The following shall be moved out eventually. - val ModifierRedo: String = "redo" - val ModifierRetry: String = "retry" - var ModifierNext: String = "next" - - // For un-named identifiers and parameters - val TempIdentifier = "tmp" - val TempParameter = "param" - - // Constructor method - val Initialize = "initialize" - - def getBuiltInType(typeInString: String) = s"${GlobalTypes.builtinPrefix}.$typeInString" -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyImportResolverPass.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyImportResolverPass.scala deleted file mode 100644 index 3ca71e605a9f..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyImportResolverPass.scala +++ /dev/null @@ -1,117 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes - -import better.files.File -import io.joern.rubysrc2cpg.deprecated.utils.PackageTable -import io.joern.x2cpg.Defines as XDefines -import io.shiftleft.semanticcpg.language.importresolver.* -import io.joern.x2cpg.passes.frontend.XImportResolverPass -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.Call -import io.shiftleft.semanticcpg.language.* - -import java.io.File as JFile -import java.util.regex.{Matcher, Pattern} -class RubyImportResolverPass(cpg: Cpg, packageTableInfo: PackageTable) extends XImportResolverPass(cpg) { - - private val pathPattern = Pattern.compile("[\"']([\\w/.]+)[\"']") - - override protected def optionalResolveImport( - fileName: String, - importCall: Call, - importedEntity: String, - importedAs: String, - diffGraph: DiffGraphBuilder - ): Unit = { - - resolveEntities(importedEntity, importCall, fileName).foreach(x => evaluatedImportToTag(x, importCall, diffGraph)) - } - - private def resolveEntities(expEntity: String, importCall: Call, fileName: String): Set[EvaluatedImport] = { - - // TODO - /* Currently we are considering only case where exposed module are Classes, - and the only way to consume them is by creating a new object as we encounter more cases, - This needs to be handled accordingly - */ - - val expResolvedPath = - if (packageTableInfo.getModule(expEntity).nonEmpty || packageTableInfo.getTypeDecl(expEntity).nonEmpty) - expEntity - else if (expEntity.contains(".")) - getResolvedPath(expEntity, fileName) - else if (cpg.file.name(s".*$expEntity.rb").nonEmpty) - getResolvedPath(s"$expEntity.rb", fileName) - else - expEntity - - // TODO Limited ResolvedMethod exposure for now, will open up after looking at more concrete examples - val finalResolved = { - if ( - packageTableInfo.getModule(expResolvedPath).nonEmpty || packageTableInfo.getTypeDecl(expResolvedPath).nonEmpty - ) { - val importNodesFromTypeDecl = packageTableInfo - .getTypeDecl(expEntity) - .flatMap { typeDeclModel => - Seq( - ResolvedMethod(s"${typeDeclModel.fullName}.${XDefines.ConstructorMethodName}", "new"), - ResolvedTypeDecl(typeDeclModel.fullName) - ) - } - .distinct - - val importNodesFromModule = packageTableInfo.getModule(expEntity).flatMap { moduleModel => - Seq(ResolvedTypeDecl(moduleModel.fullName)) - } - (importNodesFromTypeDecl ++ importNodesFromModule).toSet - } else { - val filePattern = s"${Pattern.quote(expResolvedPath)}\\.?.*" - val resolvedTypeDecls = cpg.typeDecl - .where(_.file.name(filePattern)) - .fullName - .flatMap(fullName => - Seq(ResolvedTypeDecl(fullName), ResolvedMethod(s"$fullName.${XDefines.ConstructorMethodName}", "new")) - ) - .toSet - - val resolvedModules = cpg.namespaceBlock - .whereNot(_.nameExact("")) - .where(_.file.name(filePattern)) - .flatMap(module => Seq(ResolvedTypeDecl(module.fullName))) - .toSet - - // Expose methods which are directly present in a file, without any module, TypeDecl - val resolvedMethods = cpg.method - .where(_.file.name(filePattern)) - .where(_.nameExact(":program")) - .astChildren - .astChildren - .isMethod - .flatMap(method => Seq(ResolvedMethod(method.fullName, method.name))) - .toSet - resolvedTypeDecls ++ resolvedModules ++ resolvedMethods - } - }.collectAll[EvaluatedImport].toSet - - finalResolved - } - - def getResolvedPath(expEntity: String, fileName: String) = { - val rawEntity = expEntity.stripPrefix("./") - val matcher = pathPattern.matcher(rawEntity) - val sep = Matcher.quoteReplacement(JFile.separator) - val root = s"$codeRootDir${JFile.separator}" - val currentFile = s"$root$fileName" - val entity = if (matcher.find()) matcher.group(1) else rawEntity - val resolvedPath = better.files - .File( - currentFile.stripSuffix(currentFile.split(sep).lastOption.getOrElse("")), - entity.split("\\.").headOption.getOrElse(entity) - ) - .pathAsString match { - case resPath if entity.endsWith(".rb") => s"$resPath.rb" - case resPath => resPath - } - resolvedPath.stripPrefix(root) - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyTypeHintCallLinker.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyTypeHintCallLinker.scala deleted file mode 100644 index 7c7229fcb893..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyTypeHintCallLinker.scala +++ /dev/null @@ -1,12 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes - -import io.joern.x2cpg.passes.frontend.XTypeHintCallLinker -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.Call -import io.shiftleft.semanticcpg.language.* - -class RubyTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg) { - - override def calls: Iterator[Call] = super.calls.nameNot("^(require).*") - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyTypeRecoveryPassGenerator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyTypeRecoveryPassGenerator.scala deleted file mode 100644 index b9dc8c0c80cb..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyTypeRecoveryPassGenerator.scala +++ /dev/null @@ -1,129 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes - -import io.joern.x2cpg.passes.frontend.* -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.semanticcpg.language.* -import overflowdb.BatchedUpdate.DiffGraphBuilder -import io.joern.x2cpg.Defines as XDefines - -class RubyTypeRecoveryPassGenerator(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) - extends XTypeRecoveryPassGenerator[File](cpg, config) { - override protected def generateRecoveryPass(state: XTypeRecoveryState, iteration: Int): XTypeRecovery[File] = - new RubyTypeRecovery(cpg, state, iteration) -} - -private class RubyTypeRecovery(cpg: Cpg, state: XTypeRecoveryState, iteration: Int) - extends XTypeRecovery[File](cpg, state, iteration) { - - override def compilationUnits: Iterator[File] = cpg.file.iterator - - override def generateRecoveryForCompilationUnitTask( - unit: File, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[File] = { - new RecoverForRubyFile(cpg, unit, builder, state) - } -} - -private class RecoverForRubyFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, state: XTypeRecoveryState) - extends RecoverForXCompilationUnit[File](cpg, cu, builder, state) { - - /** A heuristic method to determine if a call is a constructor or not. - */ - override protected def isConstructor(c: Call): Boolean = { - isConstructor(c.name) && c.code.charAt(0).isUpper - } - - /** A heuristic method to determine if a call name is a constructor or not. - */ - override protected def isConstructor(name: String): Boolean = - !name.isBlank && (name == "new" || name == XDefines.ConstructorMethodName) - - override def visitImport(i: Import): Unit = for { - resolvedImport <- i.call.tag - alias <- i.importedAs - } { - import io.shiftleft.semanticcpg.language.importresolver.* - EvaluatedImport.tagToEvaluatedImport(resolvedImport).foreach { - case ResolvedTypeDecl(fullName, _) => - symbolTable.append(LocalVar(fullName.split("\\.").lastOption.getOrElse(alias)), fullName) - case _ => super.visitImport(i) - } - } - override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = { - - def isMatching(cName: String, code: String) = { - val cNameList = cName.split(":program").last.split("\\.").filterNot(_.isEmpty).dropRight(1) - val codeList = code.split("\\(").head.split("[:.]").filterNot(_.isEmpty).dropRight(1) - cNameList sameElements codeList - } - - val constructorPaths = - symbolTable.get(c).filter(isMatching(_, c.code)).map(_.stripSuffix(s"$pathSep${XDefines.ConstructorMethodName}")) - associateTypes(i, constructorPaths) - } - - override def methodReturnValues(methodFullNames: Seq[String]): Set[String] = { - // Check if we have a corresponding member to resolve type - val memberTypes = methodFullNames.flatMap { fullName => - val memberName = fullName.split("\\.").lastOption - if (memberName.isDefined) { - val typeDeclFullName = fullName.stripSuffix(s".${memberName.get}") - cpg.typeDecl.fullName(typeDeclFullName).member.nameExact(memberName.get).typeFullName.l - } else - List.empty - }.toSet - if (memberTypes.nonEmpty) memberTypes else super.methodReturnValues(methodFullNames) - } - - override def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = { - if (c.name.startsWith("")) { - visitIdentifierAssignedToOperator(i, c, c.name) - } else if (symbolTable.contains(c) && isConstructor(c)) { - visitIdentifierAssignedToConstructor(i, c) - } else if (symbolTable.contains(c)) { - visitIdentifierAssignedToCallRetVal(i, c) - } else if (c.argument.headOption.exists(symbolTable.contains)) { - setCallMethodFullNameFromBase(c) - // Repeat this method now that the call has a type - visitIdentifierAssignedToCall(i, c) - } else if ( - c.argument.headOption - .exists(_.isCall) && c.argument.head - .asInstanceOf[Call] - .name - .equals(".scopeResolution") && c.argument.head - .asInstanceOf[Call] - .argument - .lastOption - .exists(symbolTable.contains) - ) { - setCallMethodFullNameFromBaseScopeResolution(c) - // Repeat this method now that the call has a type - visitIdentifierAssignedToCall(i, c) - } else { - // We can try obtain a return type for this call - visitIdentifierAssignedToCallRetVal(i, c) - } - } - - protected def setCallMethodFullNameFromBaseScopeResolution(c: Call): Set[String] = { - val recTypes = c.argument.headOption - .map { - case x: Call if x.name.equals(".scopeResolution") => - x.argument.lastOption.map(i => symbolTable.get(i)).getOrElse(Set.empty[String]) - } - .getOrElse(Set.empty[String]) - val callTypes = recTypes.map(_.concat(s"$pathSep${c.name}")) - symbolTable.append(c, callTypes) - } - - override protected def visitIdentifierAssignedToTypeRef(i: Identifier, t: TypeRef, rec: Option[String]): Set[String] = - t.typ.referencedTypeDecl - .map(_.fullName.stripSuffix("")) - .map(td => symbolTable.append(CallAlias(i.name, rec), Set(td))) - .headOption - .getOrElse(super.visitIdentifierAssignedToTypeRef(i, t, rec)) - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/utils/PackageTable.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/utils/PackageTable.scala deleted file mode 100644 index 000cdee6af56..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/deprecated/utils/PackageTable.scala +++ /dev/null @@ -1,88 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.utils - -import java.io.File as JFile -import java.util.regex.Pattern -import scala.collection.mutable - -case class MethodTableModel(methodName: String, parentClassPath: String, classType: String) -case class ModuleModel(name: String, fullName: String) -case class TypeDeclModel(name: String, fullName: String) -case class PackageContext(moduleName: String, packageTable: PackageTable) - -class PackageTable { - - val methodTableMap = mutable.HashMap[String, mutable.HashSet[MethodTableModel]]() - val moduleMapping = mutable.HashMap[String, mutable.HashSet[ModuleModel]]() - val typeDeclMapping = mutable.HashMap[String, mutable.HashSet[TypeDeclModel]]() - - def addPackageMethod(moduleName: String, methodName: String, parentClassPath: String, classType: String): Unit = { - val packageMethod = MethodTableModel(methodName, parentClassPath, classType) - methodTableMap.getOrElseUpdate(moduleName, mutable.HashSet.empty[MethodTableModel]) += packageMethod - } - - def addModule(gemOrFileName: String, moduleName: String, modulePath: String): Unit = { - val fName = gemOrFileName.split(Pattern.quote(JFile.separator)).lastOption.getOrElse(gemOrFileName) - moduleMapping.getOrElseUpdate(gemOrFileName, mutable.HashSet.empty[ModuleModel]) += ModuleModel( - moduleName, - s"$fName::program.$modulePath" - ) - } - - def addTypeDecl(gemOrFileName: String, typeDeclName: String, typeDeclPath: String): Unit = { - val fName = gemOrFileName.split(Pattern.quote(JFile.separator)).lastOption.getOrElse(gemOrFileName) - typeDeclMapping.getOrElseUpdate(gemOrFileName, mutable.HashSet.empty[TypeDeclModel]) += TypeDeclModel( - typeDeclName, - s"$fName::program.$typeDeclPath" - ) - } - - def getMethodFullNameUsingName( - packageUsed: List[String] = List(PackageTable.InternalModule), - methodName: String - ): List[String] = - packageUsed - .filter(methodTableMap.contains) - .flatMap { - case PackageTable.InternalModule => - methodTableMap(PackageTable.InternalModule) - .filter(_.methodName == methodName) - .map(method => s"${method.parentClassPath}.$methodName") - case module => - methodTableMap(module) - .filter(_.methodName == methodName) - .map(method => s"$module::program:${method.parentClassPath}$methodName") - } - - def getPackageInfo(moduleName: String): List[MethodTableModel] = { - methodTableMap.get(moduleName) match - case Some(value) => value.toList - case None => List.empty[MethodTableModel] - } - - def getModule(gemOrFileName: String): List[ModuleModel] = { - moduleMapping.get(gemOrFileName) match - case Some(value) => value.toList - case None => List.empty[ModuleModel] - } - - def getTypeDecl(gemOrFileName: String): List[TypeDeclModel] = { - typeDeclMapping.get(gemOrFileName) match - case Some(value) => value.toList - case None => List.empty[TypeDeclModel] - } - - def set(table: PackageTable): Unit = { - methodTableMap.addAll(table.methodTableMap) - moduleMapping.addAll(table.moduleMapping) - typeDeclMapping.addAll(table.typeDeclMapping) - } - def clear(): Unit = { - methodTableMap.clear - moduleMapping.clear - typeDeclMapping.clear - } -} - -object PackageTable { - val InternalModule = "" -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/AntlrContextHelpers.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/AntlrContextHelpers.scala deleted file mode 100644 index eb2e0d9ecadc..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/AntlrContextHelpers.scala +++ /dev/null @@ -1,224 +0,0 @@ -package io.joern.rubysrc2cpg.parser - -import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.TextSpan -import io.joern.rubysrc2cpg.parser.RubyParser.* -import org.antlr.v4.runtime.ParserRuleContext -import org.antlr.v4.runtime.misc.Interval -import org.slf4j.LoggerFactory - -import scala.jdk.CollectionConverters.* - -object AntlrContextHelpers { - - private val logger = LoggerFactory.getLogger(getClass) - - sealed implicit class ParserRuleContextHelper(ctx: ParserRuleContext) { - def toTextSpan: TextSpan = { - // The stopIndex could precede startIndex for rules that do not consume anything, cf. `getStop`. - // We need to make sure this doesn't happen when building the `text` field. - val startIndex = ctx.getStart.getStartIndex - val stopIndex = math.max(startIndex, ctx.getStop.getStopIndex) - TextSpan( - line = Option(ctx.getStart.getLine), - column = Option(ctx.getStart.getCharPositionInLine), - lineEnd = Option(ctx.getStop.getLine), - columnEnd = Option(ctx.getStop.getCharPositionInLine), - text = ctx.getStart.getInputStream.getText(new Interval(startIndex, stopIndex)) - ) - } - } - - sealed implicit class CompoundStatementContextHelper(ctx: CompoundStatementContext) { - def getStatements: List[ParserRuleContext] = - Option(ctx.statements()).map(_.statement().asScala.toList).getOrElse(List()) - } - - sealed implicit class NumericLiteralContextHelper(ctx: NumericLiteralContext) { - def hasSign: Boolean = Option(ctx.sign).isDefined - } - - sealed implicit class SingleOrDoubleQuotedStringContextHelper(ctx: SingleOrDoubleQuotedStringContext) { - def isInterpolated: Boolean = Option(ctx.doubleQuotedString()).exists(_.isInterpolated) - } - - sealed implicit class DoubleQuotedStringContextHelper(ctx: DoubleQuotedStringContext) { - def interpolations: List[ParserRuleContext] = ctx - .doubleQuotedStringContent() - .asScala - .filter(ctx => Option(ctx.compoundStatement()).isDefined) - .map(ctx => ctx.compoundStatement()) - .toList - def isInterpolated: Boolean = interpolations.nonEmpty - } - - sealed implicit class QuotedExpandedStringLiteralContextHelper(ctx: QuotedExpandedStringLiteralContext) { - def interpolations: List[ParserRuleContext] = ctx - .quotedExpandedLiteralStringContent() - .asScala - .filter(ctx => Option(ctx.compoundStatement()).isDefined) - .map(ctx => ctx.compoundStatement()) - .toList - def isInterpolated: Boolean = interpolations.nonEmpty - } - - sealed implicit class DoubleQuotedStringExpressionContextHelper(ctx: DoubleQuotedStringExpressionContext) { - def interpolations: List[ParserRuleContext] = ctx.doubleQuotedString().interpolations ++ ctx - .singleOrDoubleQuotedString() - .asScala - .filter(_.isInterpolated) - .flatMap(_.doubleQuotedString().interpolations) - .toList - - def concatenations: List[SingleOrDoubleQuotedStringContext] = ctx.singleOrDoubleQuotedString().asScala.toList - def isInterpolated: Boolean = ctx.doubleQuotedString().isInterpolated || concatenations.exists(_.isInterpolated) - } - - sealed implicit class DoubleQuotedSymbolExpressionContextHelper(ctx: DoubleQuotedSymbolLiteralContext) { - def interpolations: List[ParserRuleContext] = ctx.doubleQuotedString().interpolations ++ (ctx - .doubleQuotedString() :: Nil) - .filter(_.isInterpolated) - .flatMap(_.interpolations) - - def concatenations: List[DoubleQuotedStringContext] = ctx.doubleQuotedString() :: Nil - def isInterpolated: Boolean = ctx.doubleQuotedString().isInterpolated || concatenations.exists(_.isInterpolated) - } - - sealed implicit class SingleQuotedStringExpressionContextHelper(ctx: SingleQuotedStringExpressionContext) { - def concatenations: List[SingleOrDoubleQuotedStringContext] = ctx.singleOrDoubleQuotedString().asScala.toList - def isInterpolated: Boolean = concatenations.exists(_.isInterpolated) - def interpolations: List[ParserRuleContext] = - concatenations.filter(_.isInterpolated).flatMap(_.doubleQuotedString().interpolations) - } - - sealed implicit class RegularExpressionLiteralContextHelper(ctx: RegularExpressionLiteralContext) { - def isStatic: Boolean = !isDynamic - def isDynamic: Boolean = interpolations.nonEmpty - - def interpolations: List[ParserRuleContext] = ctx - .regexpLiteralContent() - .asScala - .filter(ctx => Option(ctx.compoundStatement()).isDefined) - .map(ctx => ctx.compoundStatement()) - .toList - } - - sealed implicit class QuotedExpandedRegularExpressionLiteralContextHelper( - ctx: QuotedExpandedRegularExpressionLiteralContext - ) { - - def isStatic: Boolean = !isDynamic - def isDynamic: Boolean = interpolations.nonEmpty - - def interpolations: List[ParserRuleContext] = ctx - .quotedExpandedLiteralStringContent() - .asScala - .filter(ctx => Option(ctx.compoundStatement()).isDefined) - .map(ctx => ctx.compoundStatement()) - .toList - - } - - sealed implicit class CurlyBracesBlockContextHelper(ctx: CurlyBracesBlockContext) { - def parameters: List[ParserRuleContext] = Option(ctx.blockParameter()).map(_.parameters).getOrElse(List()) - } - - sealed implicit class BlockParameterContextHelper(ctx: BlockParameterContext) { - def parameters: List[ParserRuleContext] = Option(ctx.parameterList()).map(_.parameters).getOrElse(List()) - } - - sealed implicit class CommandArgumentContextHelper(ctx: CommandArgumentContext) { - def arguments: List[ParserRuleContext] = ctx match { - case ctx: CommandCommandArgumentListContext => ctx.command() :: Nil - case ctx: CommandArgumentCommandArgumentListContext => ctx.commandArgumentList().elements - case ctx => Nil - } - } - - sealed implicit class CommandArgumentListContextHelper(ctx: CommandArgumentListContext) { - def elements: List[ParserRuleContext] = { - val primaryValues = Option(ctx.primaryValueList()).map(_.primaryValue().asScala.toList).getOrElse(List()) - val associations = Option(ctx.associationList()).map(_.association().asScala.toList).getOrElse(List()) - primaryValues ++ associations - } - } - - sealed implicit class ModifierStatementContextHelpers(ctx: ModifierStatementContext) { - def isUnless: Boolean = Option(ctx.statementModifier().UNLESS()).isDefined - def isIf: Boolean = Option(ctx.statementModifier().IF()).isDefined - } - - sealed implicit class QuotedNonExpandedArrayElementListContextHelper(ctx: QuotedNonExpandedArrayElementListContext) { - def elements: List[ParserRuleContext] = ctx.quotedNonExpandedArrayElementContent().asScala.toList - } - - sealed implicit class AssociationListContextHelper(ctx: AssociationListContext) { - def associations: List[ParserRuleContext] = ctx.association().asScala.toList - } - - sealed implicit class MethodIdentifierContextHelper(ctx: MethodIdentifierContext) { - def isAttrDeclaration: Boolean = Set("attr_reader", "attr_writer", "attr_accessor").contains(ctx.getText) - } - - sealed implicit class MandatoryOrOptionalParameterListContextHelper(ctx: MandatoryOrOptionalParameterListContext) { - def parameters: List[ParserRuleContext] = ctx.mandatoryOrOptionalParameter().asScala.toList - } - - sealed implicit class MethodParameterPartContextHelper(ctx: MethodParameterPartContext) { - def parameters: List[ParserRuleContext] = Option(ctx.parameterList()).map(_.parameters).getOrElse(List()) - } - - sealed implicit class ParameterListContextHelper(ctx: ParameterListContext) { - def parameters: List[ParserRuleContext] = { - val mandatoryOrOptionals = Option(ctx.mandatoryOrOptionalParameterList()).map(_.parameters).getOrElse(List()) - val arrayParameter = Option(ctx.arrayParameter()).toList - val hashParameter = Option(ctx.hashParameter()).toList - val procParameter = Option(ctx.procParameter()).toList - mandatoryOrOptionals ++ arrayParameter ++ hashParameter ++ procParameter - } - } - - sealed implicit class IndexingArgumentListContextHelper(ctx: IndexingArgumentListContext) { - def arguments: List[ParserRuleContext] = ctx match - case ctx: CommandIndexingArgumentListContext => List(ctx.command()) - case ctx: OperatorExpressionListIndexingArgumentListContext => - ctx.operatorExpressionList().operatorExpression().asScala.toList - case ctx: AssociationListIndexingArgumentListContext => ctx.associationList().associations - case ctx: SplattingArgumentIndexingArgumentListContext => ctx.splattingArgument() :: Nil - case ctx: OperatorExpressionListWithSplattingArgumentIndexingArgumentListContext => ctx.splattingArgument() :: Nil - case ctx => - logger.warn(s"IndexingArgumentListContextHelper - Unsupported argument type ${ctx.getClass}") - List() - } - - sealed implicit class ArgumentWithParenthesesContextHelper(ctx: ArgumentWithParenthesesContext) { - def arguments: List[ParserRuleContext] = ctx match - case _: EmptyArgumentWithParenthesesContext => List() - case ctx: ArgumentListArgumentWithParenthesesContext => ctx.argumentList().elements - case ctx => - logger.warn(s"ArgumentWithParenthesesContextHelper - Unsupported argument type ${ctx.getClass}") - List() - } - - sealed implicit class ArgumentListContextHelper(ctx: ArgumentListContext) { - def elements: List[ParserRuleContext] = ctx match - case ctx: OperatorsArgumentListContext => - val operatorExpressions = ctx.operatorExpressionList().operatorExpression().asScala.toList - val associations = Option(ctx.associationList()).fold(List())(_.association().asScala) - val splatting = Option(ctx.splattingArgument()).toList - val block = Option(ctx.blockArgument()).toList - operatorExpressions ++ associations ++ splatting ++ block - case ctx: AssociationsArgumentListContext => - Option(ctx.associationList()).map(_.associations).getOrElse(List.empty) - case ctx: SplattingArgumentArgumentListContext => - Option(ctx.splattingArgument()).toList - case ctx: BlockArgumentArgumentListContext => - Option(ctx.blockArgument()).toList - case ctx => - logger.warn(s"ArgumentListContextHelper - Unsupported element type ${ctx.getClass.getSimpleName}") - List() - } - - sealed implicit class CommandWithDoBlockContextHelper(ctx: CommandWithDoBlockContext) { - def arguments: List[ParserRuleContext] = Option(ctx.argumentList()).map(_.elements).getOrElse(Nil) - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/AntlrParser.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/AntlrParser.scala deleted file mode 100644 index 9fa7f2e47511..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/AntlrParser.scala +++ /dev/null @@ -1,120 +0,0 @@ -package io.joern.rubysrc2cpg.parser - -import better.files.File -import org.antlr.v4.runtime.* -import org.antlr.v4.runtime.atn.{ATN, ATNConfigSet} -import org.antlr.v4.runtime.dfa.DFA -import org.slf4j.LoggerFactory -import java.io.File.separator -import java.util -import scala.collection.mutable.ListBuffer -import scala.util.Try - -/** A consumable wrapper for the RubyParser class used to parse the given file and be disposed thereafter. - * @param inputDir - * the directory of the target to parse. - * @param filename - * the file path to the file to be parsed. - */ -class AntlrParser(inputDir: File, filename: String) { - - private val charStream = CharStreams.fromFileName(filename) - private val lexer = new RubyLexer(charStream) - private val tokenStream = new CommonTokenStream(RubyLexerPostProcessor(lexer)) - val parser: RubyParser = new RubyParser(tokenStream) - - def parse(): (Try[RubyParser.ProgramContext], List[String]) = { - val errors = ListBuffer[String]() - parser.removeErrorListeners() - parser.addErrorListener(new ANTLRErrorListener { - override def syntaxError( - recognizer: Recognizer[?, ?], - offendingSymbol: Any, - line: Int, - charPositionInLine: Int, - msg: String, - e: RecognitionException - ): Unit = { - val errorMessage = - s"Syntax error on ${filename.stripPrefix(s"${inputDir.pathAsString}$separator")}:$line:$charPositionInLine" - errors.append(errorMessage) - } - - override def reportAmbiguity( - recognizer: Parser, - dfa: DFA, - startIndex: Int, - stopIndex: Int, - exact: Boolean, - ambigAlts: util.BitSet, - configs: ATNConfigSet - ): Unit = {} - - override def reportAttemptingFullContext( - recognizer: Parser, - dfa: DFA, - startIndex: Int, - stopIndex: Int, - conflictingAlts: util.BitSet, - configs: ATNConfigSet - ): Unit = {} - - override def reportContextSensitivity( - recognizer: Parser, - dfa: DFA, - startIndex: Int, - stopIndex: Int, - prediction: Int, - configs: ATNConfigSet - ): Unit = {} - }) - (Try(parser.program()), errors.toList) - } -} - -/** A re-usable parser object that clears the ANTLR DFA-cache if it determines that the memory usage is becoming large. - * Once this parser is closed, the whole cache is evicted. - * - * This is done in this way since clearing the cache after each file is inefficient, since the cache must be re-built - * every time, but the cache can become unnecessarily large at times. The cache also does not evict itself at the end - * of parsing. - * - * @param clearLimit - * the percentage of used heap to clear the DFA-cache on. - */ -class ResourceManagedParser(clearLimit: Double) extends AutoCloseable { - - private val logger = LoggerFactory.getLogger(getClass) - private val runtime = Runtime.getRuntime - private var maybeDecisionToDFA: Option[Array[DFA]] = None - private var maybeAtn: Option[ATN] = None - - def parse(inputFile: File, filename: String): Try[RubyParser.ProgramContext] = { - val inputDir = if inputFile.isDirectory then inputFile else inputFile.parent - val antlrParser = AntlrParser(inputDir, filename) - val interp = antlrParser.parser.getInterpreter - // We need to grab a live instance in order to get the static variables as they are protected from static access - maybeDecisionToDFA = Option(interp.decisionToDFA) - maybeAtn = Option(interp.atn) - val usedMemory = runtime.freeMemory.toDouble / runtime.totalMemory.toDouble - if (usedMemory >= clearLimit) { - logger.debug(s"Runtime memory consumption at $usedMemory, clearing ANTLR DFA cache") - clearDFA() - } - val (programCtx, errors) = antlrParser.parse() - errors.foreach(logger.warn) - programCtx - } - - /** Clears the shared DFA cache. - */ - private def clearDFA(): Unit = if (maybeDecisionToDFA.isDefined && maybeAtn.isDefined) { - val decisionToDFA = maybeDecisionToDFA.get - val atn = maybeAtn.get - for (d <- decisionToDFA.indices) { - decisionToDFA(d) = new DFA(atn.getDecisionState(d), d) - } - } - - override def close(): Unit = clearDFA() -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/AstPrinter.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/AstPrinter.scala deleted file mode 100644 index 1488305c4269..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/AstPrinter.scala +++ /dev/null @@ -1,36 +0,0 @@ -package io.joern.rubysrc2cpg.parser - -import org.antlr.v4.runtime.ParserRuleContext -import org.antlr.v4.runtime.tree.TerminalNode - -/** General purpose ANTLR parse tree printer. - */ -object AstPrinter { - private val indentationIncrement = 1 - - private def print(level: Int, sb: StringBuilder, context: ParserRuleContext): StringBuilder = { - val indentation = " ".repeat(level) - val contextName = context.getClass.getSimpleName.stripSuffix("Context") - val nextLevel = level + indentationIncrement - sb.append(s"$indentation$contextName\n") - Option(context.children).foreach(_.forEach { - case c: ParserRuleContext => print(nextLevel, sb, c) - case t: TerminalNode => print(nextLevel, sb, t) - }) - sb - } - - private def print(level: Int, sb: StringBuilder, terminal: TerminalNode): StringBuilder = { - val indentation = " ".repeat(level) - sb.append(s"$indentation${terminal.getText}\n") - sb - } - - /** Pretty-prints an entire `ParserRuleContext` together with its descendants. - * @param context - * the context to pretty-print - * @return - * an indented, multiline string representation - */ - def print(context: ParserRuleContext): String = print(0, new StringBuilder, context).toString() -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/HereDocHandling.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/HereDocHandling.scala deleted file mode 100644 index cb31757a865f..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/HereDocHandling.scala +++ /dev/null @@ -1,43 +0,0 @@ -package io.joern.rubysrc2cpg.parser - -import better.files.EOF - -trait HereDocHandling { this: RubyLexerBase => - - /** @see - * Stack - * Overflow - */ - def heredocEndAhead(partialHeredoc: String): Boolean = - if (this.getCharPositionInLine != 0) { - // If the lexer is not at the start of a line, no end-delimiter can be possible - false - } else { - // Count WS characters to ignore - var idxWs = 0 - var wsCount = 0 - - while (this._input.LA(idxWs + 1).toChar.isWhitespace) { - wsCount += 1 - idxWs += 1 - } - - // Get the delimiter - HereDocHandling.getHereDocDelimiter(partialHeredoc) match - case Some(delimiter) if !delimiter.zipWithIndex.exists { case (c, idx) => - this._input.LA(idx + wsCount + 1) != c - } => - // If we get to this point, we know there is an end delimiter ahead in the char stream, make - // sure it is followed by a white space (or the EOF). If we don't do this, then "FOOS" would also - // be considered the end for the delimiter "FOO" - val charAfterDelimiter = this._input.LA(delimiter.length + wsCount + 1) - charAfterDelimiter == EOF || Character.isWhitespace(charAfterDelimiter) - case _ => false - } -} - -object HereDocHandling { - def getHereDocDelimiter(hereDoc: String): Option[String] = - hereDoc.split("\r?\n|\r").headOption.map(_.replaceAll("^<<[~-]\\s*", "")) -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/InterpolationHandling.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/InterpolationHandling.scala deleted file mode 100644 index dbcf14c90b3e..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/InterpolationHandling.scala +++ /dev/null @@ -1,21 +0,0 @@ -package io.joern.rubysrc2cpg.parser - -import scala.collection.mutable - -trait InterpolationHandling { this: RubyLexerBase => - - private val interpolationEndTokenType = mutable.Stack[Int]() - - def pushInterpolationEndTokenType(endTokenType: Int): Unit = { - interpolationEndTokenType.push(endTokenType) - } - - def popInterpolationEndTokenType(): Int = { - interpolationEndTokenType.pop() - } - - def isEndOfInterpolation: Boolean = { - interpolationEndTokenType.nonEmpty - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/QuotedLiteralHandling.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/QuotedLiteralHandling.scala deleted file mode 100644 index 147ae0dc2df5..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/QuotedLiteralHandling.scala +++ /dev/null @@ -1,45 +0,0 @@ -package io.joern.rubysrc2cpg.parser - -import scala.collection.mutable - -trait QuotedLiteralHandling { this: RubyLexerBase => - - private val delimiters = mutable.Stack[Int]() - private val endTokenTypes = mutable.Stack[Int]() - - private def closingDelimiterFor(char: Int): Int = char match - case '(' => ')' - case '[' => ']' - case '{' => '}' - case '<' => '>' - case c => c - - private def currentOpeningDelimiter: Int = delimiters.top - - private def currentClosingDelimiter: Int = closingDelimiterFor(currentOpeningDelimiter) - - private def isOpeningDelimiter(char: Int): Boolean = char == currentOpeningDelimiter - - private def isClosingDelimiter(char: Int): Boolean = char == currentClosingDelimiter - - def pushQuotedDelimiter(char: Int): Unit = delimiters.push(char) - - def popQuotedDelimiter(): Unit = delimiters.pop() - - def pushQuotedEndTokenType(endTokenType: Int): Unit = endTokenTypes.push(endTokenType) - - def popQuotedEndTokenType(): Int = endTokenTypes.pop() - - def consumeQuotedCharAndMaybePopMode(char: Int): Unit = { - if (isClosingDelimiter(char)) { - popQuotedDelimiter() - - if (delimiters.isEmpty) { - setType(endTokenTypes.pop()) - popMode() - } - } else if (isOpeningDelimiter(char)) { - pushQuotedDelimiter(char) - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RegexLiteralHandling.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RegexLiteralHandling.scala deleted file mode 100644 index 5ca3054c7a0c..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RegexLiteralHandling.scala +++ /dev/null @@ -1,86 +0,0 @@ -package io.joern.rubysrc2cpg.parser - -import io.joern.rubysrc2cpg.parser.RubyLexer.* -import org.antlr.v4.runtime.Recognizer.EOF - -trait RegexLiteralHandling { this: RubyLexerBase => - - /* When encountering '/', we need to decide whether this is a binary operator (e.g. `x / y`) or - * a regular expression delimiter (e.g. `/(eu|us)/`) occurrence. Our approach is to look at the - * previously emitted token and decide accordingly. - */ - private val regexTogglingTokens: Set[Int] = Set( - // When '/' occurs after an opening parenthesis, brace or bracket. - LPAREN, - LCURLY, - LBRACK, - // When '/' occurs after a NL. - NL, - // When '/' occurs after a ','. - COMMA, - // When '/' occurs after a ':'. - COLON, - // When '/' occurs after 'when'. - WHEN, - // When '/' occurs after 'unless'. - UNLESS, - // When '/' occurs after an operator. - EMARK, - EMARKEQ, - EMARKTILDE, - AMP, - AMP2, - AMPDOT, - BAR, - BAR2, - EQ, - EQ2, - EQ3, - CARET, - LTEQGT, - EQGT, - EQTILDE, - GT, - GTEQ, - LT, - LTEQ, - LT2, - GT2, - PLUS, - MINUS, - STAR, - STAR2, - SLASH, - PERCENT, - TILDE, - PLUSAT, - MINUSAT, - ASSIGNMENT_OPERATOR - ) - - /** To be invoked when encountering `/`, deciding if it should emit a `REGULAR_EXPRESSION_START` token. */ - protected def isStartOfRegexLiteral: Boolean = { - val isFirstTokenInTheStream = previousNonWsToken.isEmpty - val isRegexTogglingToken = regexTogglingTokens.contains(previousNonWsTokenTypeOrEOF()) - - isFirstTokenInTheStream || isRegexTogglingToken || isInCommandArgumentPosition - } - - /** Decides if the current `/` is being used as an argument to a command, based on the observation that such literals - * may not start with a WS. E.g. `puts /x/` is valid, but `puts / x/` is not. - */ - private def isInCommandArgumentPosition: Boolean = { - val previousNonWsIsIdentifier = - previousNonWsTokenTypeOrEOF() == LOCAL_VARIABLE_IDENTIFIER || isControlStructureStart - val previousIsWs = previousTokenTypeOrEOF() == WS - val nextCharIsWs = _input.LA(1) == ' ' - previousNonWsIsIdentifier && previousIsWs && !nextCharIsWs - } - - private def isControlStructureStart: Boolean = { - previousNonWsTokenTypeOrEOF() == IF - || previousNonWsTokenTypeOrEOF() == UNLESS - || previousNonWsTokenTypeOrEOF() == UNTIL || previousNonWsTokenTypeOrEOF() == YIELD - || previousNonWsTokenTypeOrEOF() == WHILE - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyAstGenRunner.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyAstGenRunner.scala new file mode 100644 index 000000000000..eaecbe46c3a3 --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyAstGenRunner.scala @@ -0,0 +1,216 @@ +package io.joern.rubysrc2cpg.parser + +import better.files.File +import io.joern.rubysrc2cpg.Config +import io.joern.x2cpg.SourceFiles +import io.joern.x2cpg.astgen.AstGenRunner.{AstGenProgramMetaData, AstGenRunnerResult, DefaultAstGenRunnerResult} +import io.joern.x2cpg.astgen.AstGenRunnerBase +import io.joern.x2cpg.utils.{Environment, ExternalCommand} +import org.jruby.RubyInstanceConfig +import org.jruby.embed.{LocalContextScope, LocalVariableBehavior, PathType, ScriptingContainer} +import org.slf4j.LoggerFactory + +import java.io.File.separator +import java.io.{ByteArrayOutputStream, InputStream, PrintStream} +import java.nio.file.{Files, Path, Paths, StandardCopyOption} +import java.util +import java.util.jar.JarFile +import scala.collection.mutable +import scala.jdk.CollectionConverters.* +import scala.util.{Failure, Success, Try, Using} + +class RubyAstGenRunner(config: Config) extends AstGenRunnerBase(config) { + + private val logger = LoggerFactory.getLogger(getClass) + + override def fileFilter(file: String, out: File): Boolean = { + file.stripSuffix(".json").replace(out.pathAsString, config.inputPath) match { + case filePath if isIgnoredByUserConfig(filePath) => false + case filePath if isIgnoredByDefaultRegex(filePath) => false + case filePath if filePath.endsWith(".csproj") => false + case _ => true + } + } + + private def isIgnoredByDefaultRegex(filePath: String): Boolean = { + config.defaultIgnoredFilesRegex.exists(_.matches(filePath)) + } + + override def skippedFiles(in: File, astGenOut: List[String]): List[String] = { + val diagnosticMap = mutable.LinkedHashMap.empty[String, Seq[String]] + + def addReason(reason: String, lastFile: Option[String] = None) = { + val key = lastFile.getOrElse(diagnosticMap.last._1) + diagnosticMap.updateWith(key) { + case Some(x) => Option(x :+ reason) + case None => Option(reason :: Nil) + } + } + + astGenOut.map(_.strip()).foreach { + case s"[WARN] $reason - $fileName" => addReason(reason, Option(fileName)) + case s"[ERR] '$fileName' - $reason" => addReason(reason, Option(fileName)) + case s"[ERR] Failed to parse $fileName: $reason" => + addReason(s"Failed to parse: $reason", Option(fileName)) + case s"[INFO] Processed: $fileName -> $_" => diagnosticMap.put(fileName, Nil) + case s"[INFO] Excluding: $fileName" => addReason("Skipped", Option(fileName)) + case _ => // ignore + } + + diagnosticMap.flatMap { + case (filename, Nil) => + logger.debug(s"Successfully parsed '$filename'") + None + case (filename, "Skipped" :: Nil) => + logger.debug(s"Skipped '$filename' due to file filter") + Option(filename) + case (filename, diagnostics) => + logger.warn( + s"Parsed '$filename' with the following diagnostics:\n${diagnostics.map(x => s" - $x").mkString("\n")}" + ) + Option(filename) + }.toList + } + + override def runAstGenNative(in: String, out: File, exclude: String, include: String)(implicit + metaData: AstGenProgramMetaData + ): Try[Seq[String]] = { + try { + Using.resource(prepareExecutionEnvironment("ruby_ast_gen")) { env => + val cwd = env.path.toAbsolutePath.toString + val excludeCommand = if (exclude.isEmpty) Array.empty[String] else Array("-e", s"$exclude") + val gemPath = Seq(cwd, "vendor", "bundle", "jruby", "3.1.0").mkString(separator) + val rubyArgs = Array("-o", out.toString(), "-i", in).appendedAll(excludeCommand).filterNot(_.isBlank) + val mainScript = Seq(cwd, "exe", "ruby_ast_gen").mkString(separator) + executeWithJRuby(mainScript, cwd, rubyArgs, gemPath) + } + } catch { + case tempPathException: Exception => Failure(tempPathException) + } + } + + private def executeWithJRuby( + mainScript: String, + cwd: String, + rubyArgs: Array[String], + gemPath: String + ): Try[Seq[String]] = { + val outStream = new ByteArrayOutputStream() + val errStream = new ByteArrayOutputStream() + val container = new ScriptingContainer(LocalContextScope.SINGLETHREAD, LocalVariableBehavior.TRANSIENT) + val config = container.getProvider.getRubyInstanceConfig + container.setCompileMode(RubyInstanceConfig.CompileMode.OFF) + container.setNativeEnabled(false) + container.setObjectSpaceEnabled(true) + container.setCurrentDirectory(cwd) + container.setOutput(new PrintStream(outStream)) + container.setError(new PrintStream(errStream)) + config.setLoadGemfile(true) + container.setArgv(rubyArgs) + container.setEnvironment(Map("GEM_PATH" -> gemPath, "GEM_FILE" -> gemPath).asJava) + config.setHasShebangLine(true) + config.setHardExit(false) + + Try { + container.runScriptlet(PathType.ABSOLUTE, mainScript) + outStream.toString.split("\n").toIndexedSeq ++ errStream.toString.split("\n") + } + } + + private def prepareExecutionEnvironment(resourceDir: String): ExecutionEnvironment = { + val resourceUrl = getClass.getClassLoader.getResource(resourceDir) + if (resourceUrl == null) { + throw new IllegalArgumentException(s"Resource sub-directory '$resourceDir' not found.") + } + + resourceUrl.getProtocol match { + case "jar" => + val tempPath = Files.createTempDirectory("ruby_ast_gen-") + val jarPath = resourceUrl.getPath.split("!")(0).stripPrefix("file:") + val jarFile = new JarFile(jarPath) + + val entries = jarFile.entries().asScala.filter(_.getName.startsWith(resourceDir + "/")) + entries.foreach { entry => + val entryPath = tempPath.resolve(entry.getName.stripPrefix(resourceDir + "/")) + if (entry.isDirectory) { + Files.createDirectories(entryPath) + } else { + Files.createDirectories(entryPath.getParent) + val inputStream: InputStream = jarFile.getInputStream(entry) + try { + Files.copy(inputStream, entryPath, StandardCopyOption.REPLACE_EXISTING) + if entryPath.endsWith("ruby_ast_gen") then entryPath.toFile.setExecutable(true, true) + } finally { + inputStream.close() + } + } + } + TempDir(tempPath) + case "file" => + val resourcePath = Paths.get(resourceUrl.toURI) + val mainScript = resourcePath.resolve("exe").resolve("ruby_ast_gen") + mainScript.toFile.setExecutable(true, false) + LocalDir(resourcePath) + case x => + throw new IllegalArgumentException(s"Resources is within an unsupported environment '$x'.") + } + } + + override def execute(out: File): AstGenRunnerResult = { + implicit val metaData: AstGenProgramMetaData = config.astGenMetaData + val in = File(config.inputPath) + logger.info(s"Running ${metaData.name} on '${config.inputPath}'") + + val combineIgnoreRegex = + if (config.ignoredFilesRegex.toString().isEmpty && config.defaultIgnoredFilesRegex.toString.nonEmpty) { + config.defaultIgnoredFilesRegex.mkString("|") + } else if (config.ignoredFilesRegex.toString().nonEmpty && config.defaultIgnoredFilesRegex.toString.isEmpty) { + config.ignoredFilesRegex.toString() + } else if (config.ignoredFilesRegex.toString().nonEmpty && config.defaultIgnoredFilesRegex.toString().nonEmpty) { + s"((${config.ignoredFilesRegex.toString()})|(${config.defaultIgnoredFilesRegex.mkString("|")}))" + } else { + "" + } + + runAstGenNative(config.inputPath, out, combineIgnoreRegex, "") match { + case Success(result) => + val srcFiles = SourceFiles.determine( + out.toString(), + Set(".json"), + ignoredDefaultRegex = Option(config.defaultIgnoredFilesRegex), + ignoredFilesRegex = Option(config.ignoredFilesRegex), + ignoredFilesPath = Option(config.ignoredFiles) + ) + val parsed = filterFiles(srcFiles, out) + val skipped = skippedFiles(in, result.toList) + DefaultAstGenRunnerResult(parsed, skipped) + case Failure(f) => + logger.error(s"\t- running ${metaData.name} failed!", f) + DefaultAstGenRunnerResult() + } + } + + private sealed trait ExecutionEnvironment extends AutoCloseable { + def path: Path + + def close(): Unit = {} + } + + private case class TempDir(path: Path) extends ExecutionEnvironment { + + override def close(): Unit = { + def cleanUpDir(f: Path): Unit = { + if (Files.isDirectory(f)) { + Files.list(f).iterator.asScala.foreach(cleanUpDir) + } + Files.deleteIfExists(f) + } + + cleanUpDir(path) + } + + } + + private case class LocalDir(path: Path) extends ExecutionEnvironment + +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonAst.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonAst.scala new file mode 100644 index 000000000000..32aafa11ac36 --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonAst.scala @@ -0,0 +1,209 @@ +package io.joern.rubysrc2cpg.parser + +import io.shiftleft.codepropertygraph.generated.Operators +import upickle.default.* + +/** The JSON key values, in alphabetical order. + */ +object ParserKeys { + val Alias = "alias" + val Arguments = "arguments" + val As = "as" + val Base = "base" + val Body = "body" + val Bodies = "bodies" + val Call = "call" + val CallName = "call_name" + val CaseExpression = "case_expression" + val Children = "children" + val Code = "code" + val Collection = "collection" + val Condition = "condition" + val Conditions = "conditions" + val Def = "def" + val ElseClause = "else_clause" + val ElseBranch = "else_branch" + val End = "end" + val ExecList = "exec_list" + val ExecVar = "exec_var" + val FilePath = "file_path" + val Guard = "guard" + val Key = "key" + val Left = "left" + val Lhs = "lhs" + val MetaData = "meta_data" + val Name = "name" + val Op = "op" + val ParamIdx = "param_idx" + val Pattern = "pattern" + val RelFilePath = "rel_file_path" + val Receiver = "receiver" + val Right = "right" + val Rhs = "rhs" + val Statement = "statement" + val Start = "start" + val SuperClass = "superclass" + val ThenBranch = "then_branch" + val Type = "type" + val Value = "value" + val Values = "values" + val Variable = "variable" + val WhenClauses = "when_clauses" +} + +enum AstType(val name: String) { + case Alias extends AstType("alias") + case And extends AstType("and") + case AndAssign extends AstType("and_asgn") + case Arg extends AstType("arg") + case Args extends AstType("args") + case Array extends AstType("array") + case ArrayPattern extends AstType("array_pattern") + case ArrayPatternWithTail extends AstType("array_pattern_with_tail") + case BackRef extends AstType("back_ref") + case Begin extends AstType("begin") + case Block extends AstType("block") + case BlockArg extends AstType("blockarg") + case BlockPass extends AstType("block_pass") + case BlockWithNumberedParams extends AstType("numblock") + case Break extends AstType("break") + case CaseExpression extends AstType("case") + case CaseMatchStatement extends AstType("case_match") + case ClassDefinition extends AstType("class") + case ClassVariable extends AstType("cvar") + case ClassVariableAssign extends AstType("cvasgn") + case ConstVariableAssign extends AstType("casgn") + case ConditionalSend extends AstType("csend") + case Defined extends AstType("defined?") + case DynamicString extends AstType("dstr") + case DynamicSymbol extends AstType("dsym") + case Ensure extends AstType("ensure") + case ExclusiveFlipFlop extends AstType("eflipflop") + case ExclusiveRange extends AstType("erange") + case ExecutableString extends AstType("xstr") + case False extends AstType("false") + case FindPattern extends AstType("find_pattern") + case Float extends AstType("float") + case ForStatement extends AstType("for") + case ForPostStatement extends AstType("for_post") + case ForwardArg extends AstType("forward_arg") + case ForwardArgs extends AstType("forward_args") + case ForwardedArgs extends AstType("forwarded_args") + case GlobalVariable extends AstType("gvar") + case GlobalVariableAssign extends AstType("gvasgn") + case Hash extends AstType("hash") + case HashPattern extends AstType("hash_pattern") + case Identifier extends AstType("ident") + case IfGuard extends AstType("if_guard") + case IfStatement extends AstType("if") + case InclusiveFlipFlop extends AstType("iflipflop") + case InclusiveRange extends AstType("irange") + case InPattern extends AstType("in_pattern") + case Int extends AstType("int") + case InstanceVariable extends AstType("ivar") + case InstanceVariableAssign extends AstType("ivasgn") + case KwArg extends AstType("kwarg") + case KwBegin extends AstType("kwbegin") + case KwNilArg extends AstType("kwnilarg") + case KwOptArg extends AstType("kwoptarg") + case KwRestArg extends AstType("kwrestarg") + case KwSplat extends AstType("kwsplat") + case LocalVariable extends AstType("lvar") + case LocalVariableAssign extends AstType("lvasgn") + case MatchAlt extends AstType("match_alt") + case MatchAs extends AstType("match_as") + case MatchNilPattern extends AstType("match_nil_pattern") + case MatchPattern extends AstType("match_pattern") + case MatchPatternP extends AstType("match_pattern_p") + case MatchRest extends AstType("match_rest") + case MatchVariable extends AstType("match_var") + case MatchWithLocalVariableAssign extends AstType("match_with_lvasgn") + case MethodDefinition extends AstType("def") + case ModuleDefinition extends AstType("module") + case MultipleAssignment extends AstType("masgn") + case MultipleLeftHandSide extends AstType("mlhs") + case Next extends AstType("next") + case Nil extends AstType("nil") + case NthRef extends AstType("nth_ref") + case OperatorAssign extends AstType("op_asgn") + case OptionalArgument extends AstType("optarg") + case Or extends AstType("or") + case OrAssign extends AstType("or_asgn") + case Pair extends AstType("pair") + case PostExpression extends AstType("postexe") + case PreExpression extends AstType("preexe") + case ProcArgument extends AstType("procarg0") + case Rational extends AstType("rational") + case Redo extends AstType("redo") + case Retry extends AstType("retry") + case Return extends AstType("return") + case RegexExpression extends AstType("regexp") + case RegexOption extends AstType("regopt") + case ResBody extends AstType("resbody") + case RestArg extends AstType("restarg") + case RescueStatement extends AstType("rescue") + case ScopedConstant extends AstType("const") + case Self extends AstType("self") + case Send extends AstType("send") + case ShadowArg extends AstType("shadowarg") + case SingletonMethodDefinition extends AstType("defs") + case SingletonClassDefinition extends AstType("sclass") + case Splat extends AstType("splat") + case StaticString extends AstType("str") + case StaticSymbol extends AstType("sym") + case Super extends AstType("super") + case SuperNoArgs extends AstType("zsuper") + case TopLevelConstant extends AstType("cbase") + case True extends AstType("true") + case UnDefine extends AstType("undef") + case UnlessExpression extends AstType("unless") + case UnlessGuard extends AstType("unless_guard") + case UntilExpression extends AstType("until") + case UntilPostExpression extends AstType("until_post") + case WhenStatement extends AstType("when") + case WhileStatement extends AstType("while") + case WhilePostStatement extends AstType("while_post") + case Yield extends AstType("yield") +} + +object AstType { + def fromString(input: String): Option[AstType] = AstType.values.find(_.name == input) +} + +object BinaryOperators { + private val BinaryOperators: Set[String] = + Set( + "+", + "-", + "*", + "/", + "%", + "**", + "==", + "===", + "!=", + "<", + "<=", + ">", + ">=", + "<=>", + "&&", + "and", + "or", + "||", + "&", + "|", + "^", + // "<<" -> Operators.shiftLeft, Note: Generally Ruby abstracts this as an append operator based on the LHS + ">>" + ) + + def isBinaryOperatorName(op: String): Boolean = BinaryOperators.contains(op) +} + +object UnaryOperators { + private val UnaryOperators: Set[String] = + Set("!", "not", "~", "+", "-") + + def isUnaryOperatorName(op: String): Boolean = UnaryOperators.contains(op) +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonHelpers.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonHelpers.scala new file mode 100644 index 000000000000..913d8ed1e6a3 --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonHelpers.scala @@ -0,0 +1,379 @@ +package io.joern.rubysrc2cpg.parser + +import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{ + AllowedTypeDeclarationChild, + ArrayLiteral, + ClassFieldIdentifier, + DefaultMultipleAssignment, + FieldsDeclaration, + MemberAccess, + MethodDeclaration, + ProcedureDeclaration, + RubyExpression, + RubyFieldIdentifier, + SelfIdentifier, + SimpleCall, + SimpleIdentifier, + SingleAssignment, + SingletonClassDeclaration, + SingletonMethodDeclaration, + SplattingRubyNode, + StatementList, + StaticLiteral, + TextSpan, + TypeDeclBodyCall, + UnaryExpression +} +import io.joern.rubysrc2cpg.passes.Defines +import io.joern.rubysrc2cpg.passes.Defines.prefixAsCoreType +import org.slf4j.LoggerFactory +import upickle.core.* +import upickle.default.* + +object RubyJsonHelpers { + + private val logger = LoggerFactory.getLogger(getClass) + + implicit class JsonObjHelper(o: ujson.Obj) { + + def toTextSpan: TextSpan = { + val metaData = + if (o.obj.contains(ParserKeys.MetaData)) read[MetaData](o(ParserKeys.MetaData)) + else read[MetaData](o) + + val offset = Option(metaData.offsetStart) -> Option(metaData.offsetEnd) match { + case (Some(start), Some(end)) => Option(start -> end) + case _ => None + } + + TextSpan( + line = Option(metaData.lineNumber).filterNot(_ == -1), + column = Option(metaData.columnNumber).filterNot(_ == -1), + lineEnd = Option(metaData.lineNumberEnd).filterNot(_ == -1), + columnEnd = Option(metaData.columnNumberEnd).filterNot(_ == -1), + offset = offset, + text = metaData.code + ) + } + + def visitOption(key: String)(implicit visit: ujson.Value => RubyExpression): Option[RubyExpression] = + if contains(key) then Option(visit(o(key))) else None + + def visitArray(key: String)(implicit visit: ujson.Value => RubyExpression): List[RubyExpression] = { + o(key).arr.map(visit).toList + } + + def contains(key: String): Boolean = o.obj.get(key).exists(x => x != null && x != ujson.Null) + + } + + protected def nilLiteral(span: TextSpan): StaticLiteral = + StaticLiteral(prefixAsCoreType(Defines.NilClass))(span.spanStart("nil")) + + def createClassBodyAndFields( + obj: ujson.Obj + )(implicit visit: ujson.Value => RubyExpression): (StatementList, List[RubyExpression & RubyFieldIdentifier]) = { + + def bodyMethod(fieldStatements: List[RubyExpression]): MethodDeclaration = { + val body = fieldStatements + .map { + case field: SimpleIdentifier => + val assignmentSpan = field.span.spanStart(s"${field.span.text} = nil") + SingleAssignment(ClassFieldIdentifier()(field.span), "=", nilLiteral(field.span))(assignmentSpan) + case field: RubyFieldIdentifier => + val assignmentSpan = field.span.spanStart(s"${field.span.text} = nil") + SingleAssignment(field, "=", nilLiteral(field.span))(assignmentSpan) + case assignment @ SingleAssignment(_: RubyFieldIdentifier, _, _) => assignment + case assignment @ SingleAssignment(lhs: SimpleIdentifier, _, _) => + assignment.copy(lhs = ClassFieldIdentifier()(lhs.span))(assignment.span) + case otherExpr => otherExpr + } + .distinctBy { + case _ @SingleAssignment(lhs: RubyFieldIdentifier, _, _) => lhs.text + case x => x + } + + MethodDeclaration(Defines.TypeDeclBody, Nil, StatementList(body)(obj.toTextSpan.spanStart(s"(...)")))( + obj.toTextSpan.spanStart(s"def ; (...); end") + ) + } + + /** @param expr + * An expression that is a direct child to a class or module. + * @return + * true if the expression constitutes field-related behaviour, false if otherwise. + */ + def isFieldStmt(expr: RubyExpression): Boolean = { + expr match { + case _: SingleAssignment => true + case _: SimpleIdentifier => true + case _: RubyFieldIdentifier => true + case _ => false + } + } + + /** @param expr + * An expression that is a direct child to a class or module. + * @return + * true if the expression is a Splatting Field Declaration (`attr_x(*foo)`), false otherwise. + */ + def isSplattingField(expr: RubyExpression): Boolean = { + expr match { + case x: FieldsDeclaration if x.isSplattingFieldDecl => true + case _: AllowedTypeDeclarationChild => false + case _ => false + } + } + + /** Extracts a field from the expression. + * @param expr + * An expression that is a direct child to a class or module. + */ + def getFields( + expr: RubyExpression, + typeDeclChildStatements: Boolean = true + ): List[RubyExpression & RubyFieldIdentifier] = { + expr match { + case field: SimpleIdentifier if typeDeclChildStatements => ClassFieldIdentifier()(field.span) :: Nil + case field: RubyFieldIdentifier if typeDeclChildStatements => field :: Nil + case _ @SingleAssignment(lhs: RubyFieldIdentifier, _, _) => lhs :: Nil + case _ @SingleAssignment(lhs: SimpleIdentifier, _, _) if typeDeclChildStatements => + ClassFieldIdentifier()(lhs.span) :: Nil + case proc: ProcedureDeclaration => getFields(proc.body, false) + case _ @StatementList(stmts) => stmts.flatMap(x => getFields(x, typeDeclChildStatements)).distinctBy(_.text) + case _ => Nil + } + } + + /** Attempts to evaluate and parse the collection associated with the splattingField, generating FieldDeclarations + * for each of the elements. + * @param fieldStmts + * List of all the field statements + * @param splattingFields + * List of splatting fields + * @return + * List of: + * - Some(_) => if splattingField either is evaluated to a list of FieldDeclarations, otherwise a SimpleCall + * - None => if splattingField cannot be evaluated to either FieldsDeclaration or SimpleCall + */ + def lowerSplattingFieldDecl( + fieldStmts: List[RubyExpression], + splattingFields: List[RubyExpression] + ): List[Option[RubyExpression]] = { + splattingFields.flatMap { + case x @ FieldsDeclaration(fieldName :: Nil, accessType) if fieldName.isInstanceOf[SplattingRubyNode] => + fieldStmts.map { + case _ @SingleAssignment(lhs: SimpleIdentifier, _, rhs: MemberAccess) + if rhs.memberName == "freeze" && lhs.span.text == fieldName.span.text.stripPrefix("*") => + rhs.target match { + case y: ArrayLiteral => + Some(FieldsDeclaration(y.elements, accessType)(x.span)) + case _ => None + } + case _ @SingleAssignment(_: SimpleIdentifier, _, rhs: ArrayLiteral) => + Some(FieldsDeclaration(rhs.elements, accessType)(x.span)) + case _ => + Some( + SimpleCall(SimpleIdentifier()(x.span.spanStart(accessType)), List(fieldName))( + x.span.spanStart(s"$accessType(${fieldName.span.text})") + ) + ) + } + case _ => None + } + } + + obj.visitOption(ParserKeys.Body).map(lowerSingletonClassDecls) match { + case Some(stmtList @ StatementList(expression :: Nil)) if expression.isInstanceOf[AllowedTypeDeclarationChild] => + if (isSplattingField(expression)) { + val splattingField = expression.asInstanceOf[FieldsDeclaration] + splattingField.fieldNames.headOption match { + case Some(splattingFieldName) => + val nonExpandedSplattingFieldCall = + SimpleCall( + SimpleIdentifier()(expression.span.spanStart(splattingField.accessType)), + List(splattingFieldName) + )(expression.span.spanStart(s"${splattingField.accessType}(${splattingFieldName.span.text})")) + ( + StatementList(bodyMethod(List(nonExpandedSplattingFieldCall)) :: Nil)(stmtList.span), + getFields(expression) + ) + case None => + logger.warn(s"No fieldName found for Splatting Field Decl: ${splattingField.span.text}") + (StatementList(bodyMethod(Nil) :: expression :: Nil)(stmtList.span), getFields(expression)) + } + } else { + (StatementList(bodyMethod(Nil) :: expression :: Nil)(stmtList.span), getFields(expression)) + } + case Some(stmtList @ StatementList(expression :: Nil)) if isFieldStmt(expression) => + (StatementList(bodyMethod(expression :: Nil) :: Nil)(stmtList.span), getFields(expression)) + case Some(stmtList: StatementList) => + val (fieldStmts, otherStmts) = stmtList.statements.partition(isFieldStmt) + val (typeDeclStmts, bodyStmts) = otherStmts.partition(_.isInstanceOf[AllowedTypeDeclarationChild]) + val (splattingFields, otherTypeDeclStmts) = typeDeclStmts.partition(isSplattingField) + val (expandedSplattingFields, nonExpandedSplattingFieldsCalls) = + lowerSplattingFieldDecl(fieldStmts, splattingFields) + .filter(_.isDefined) + .map(_.get) + .partition(_.isInstanceOf[FieldsDeclaration]) + + val fields = + (fieldStmts.flatMap(x => getFields(x)) ++ otherTypeDeclStmts.flatMap(x => getFields(x))) + .distinctBy(_.text) + val body = stmtList.copy(statements = + bodyMethod( + fieldStmts ++ otherTypeDeclStmts.flatMap(x => getFields(x)) ++ bodyStmts ++ nonExpandedSplattingFieldsCalls + ) +: (otherTypeDeclStmts ++ expandedSplattingFields) + )(stmtList.span) + + (body, fields) + case None => (StatementList(bodyMethod(Nil) :: Nil)(obj.toTextSpan.spanStart("")), Nil) + } + } + + def createBodyMemberCall(name: String, textSpan: TextSpan): TypeDeclBodyCall = { + TypeDeclBodyCall( + MemberAccess(SelfIdentifier()(textSpan.spanStart(Defines.Self)), "::", name)( + textSpan.spanStart(s"${Defines.Self}::$name") + ), + name + )(textSpan.spanStart(s"${Defines.Self}::$name::${Defines.TypeDeclBody}")) + } + + def getParts(memberAccess: MemberAccess): List[String] = { + memberAccess.target match { + case targetMemberAccess: MemberAccess => getParts(targetMemberAccess) :+ memberAccess.memberName + case expr => expr.text :: memberAccess.memberName :: Nil + } + } + + def lowerMultipleAssignment( + obj: ujson.Obj, + lhsNodes: List[RubyExpression], + rhsNodes: List[RubyExpression], + defaultResult: () => RubyExpression, + nilResult: () => RubyExpression + ): RubyExpression = { + + /** Recursively expand and duplicate splatting nodes so that they line up with what they consume. + * + * @param nodes + * the splat nodes. + * @param expandSize + * how many more duplicates to create. + */ + def slurp(nodes: List[RubyExpression], expandSize: Int): List[RubyExpression] = nodes match { + case (head: SplattingRubyNode) :: tail if expandSize > 0 => head :: slurp(head :: tail, expandSize - 1) + case head :: tail => head :: slurp(tail, expandSize) + case Nil => List.empty + } + val op = "=" + lazy val defaultAssignments = lhsNodes + .zipAll(rhsNodes, defaultResult(), nilResult()) + .map { case (lhs, rhs) => SingleAssignment(lhs, op, rhs)(obj.toTextSpan) } + + val assignments = if ((lhsNodes ++ rhsNodes).exists(_.isInstanceOf[SplattingRubyNode])) { + rhsNodes.size - lhsNodes.size match { + // Handle slurping the RHS values + case x if x > 0 => { + val slurpedLhs = slurp(lhsNodes, x) + + slurpedLhs + .zip(rhsNodes) + .groupBy(_._1) + .toSeq + .map { case (lhsNode, xs) => lhsNode -> xs.map(_._2) } + .sortBy { x => slurpedLhs.indexOf(x._1) } // groupBy produces a map which discards insertion order + .map { + case (SplattingRubyNode(lhs), rhss) => + SingleAssignment(lhs, op, ArrayLiteral(rhss)(obj.toTextSpan))(obj.toTextSpan) + case (lhs, rhs :: Nil) => SingleAssignment(lhs, op, rhs)(obj.toTextSpan) + case (lhs, rhss) => SingleAssignment(lhs, op, ArrayLiteral(rhss)(obj.toTextSpan))(obj.toTextSpan) + } + .toList + } + // Handle splitting the RHS values + case x if x < 0 => { + val slurpedRhs = slurp(rhsNodes, Math.abs(x)) + + lhsNodes + .zip(slurpedRhs) + .groupBy(_._2) + .toSeq + .map { case (rhsNode, xs) => rhsNode -> xs.map(_._1) } + .sortBy { x => slurpedRhs.indexOf(x._1) } // groupBy produces a map which discards insertion order + .flatMap { + case (SplattingRubyNode(rhs), lhss) => + lhss.map(SingleAssignment(_, op, SplattingRubyNode(rhs)(rhs.span))(obj.toTextSpan)) + case (rhs, lhs :: Nil) => Seq(SingleAssignment(lhs, op, rhs)(obj.toTextSpan)) + case (rhs, lhss) => lhss.map(SingleAssignment(_, op, SplattingRubyNode(rhs)(rhs.span))(obj.toTextSpan)) + } + .toList + } + case _ => defaultAssignments + } + } else { + val diff = rhsNodes.size - lhsNodes.size + if diff < 0 then defaultAssignments.dropRight(Math.abs(diff)) else defaultAssignments + } + DefaultMultipleAssignment(assignments)(obj.toTextSpan) + } + + def infinityUpperBound(obj: ujson.Obj): MemberAccess = + MemberAccess( + SimpleIdentifier(Option(prefixAsCoreType(Defines.Float)))(obj.toTextSpan.spanStart("Float")), + "::", + "INFINITY" + )(obj.toTextSpan.spanStart("Float::INFINITY")) + + def infinityLowerBound(obj: ujson.Obj): UnaryExpression = + UnaryExpression( + "-", + MemberAccess( + SimpleIdentifier(Option(prefixAsCoreType(Defines.Float)))(obj.toTextSpan.spanStart("Float")), + "::", + "INFINITY" + )(obj.toTextSpan.spanStart("Float::INFINITY")) + )(obj.toTextSpan.spanStart("-Float::INFINITY")) + + def lowerSingletonClassDecls(classBody: RubyExpression): StatementList = { + val loweredStmts = classBody match { + case x: StatementList => lowerSingletonClassDeclarations(x) + case x => lowerSingletonClassDeclarations(StatementList(List(x))(x.span)) + } + + val stmts = loweredStmts match { + case StatementList(stmts) => stmts + case x => List(x) + } + + StatementList(stmts)(classBody.span) + } + + private def lowerSingletonClassDeclarations(classBody: RubyExpression): RubyExpression = { + classBody match { + case stmtList: StatementList => + StatementList(stmtList.statements.flatMap { + case _ @SingletonClassDeclaration(_, baseClass: Some[RubyExpression], body: StatementList, _) => + body.statements.map { + case method @ MethodDeclaration(methodName, parameters, body) => + SingletonMethodDeclaration(baseClass.get, methodName, parameters, body)(method.span) + case nonMethodStatement => nonMethodStatement + } + case nonStmtListBody => nonStmtListBody :: Nil + })(stmtList.span) + case nonStmtList => nonStmtList + } + } + + private case class MetaData( + code: String, + @upickle.implicits.key("start_line") lineNumber: Int, + @upickle.implicits.key("start_column") columnNumber: Int, + @upickle.implicits.key("end_line") lineNumberEnd: Int, + @upickle.implicits.key("end_column") columnNumberEnd: Int, + @upickle.implicits.key("offset_start") offsetStart: Int, + @upickle.implicits.key("offset_end") offsetEnd: Int + ) derives ReadWriter + +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonParser.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonParser.scala new file mode 100644 index 000000000000..d0cdab59247e --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonParser.scala @@ -0,0 +1,20 @@ +package io.joern.rubysrc2cpg.parser + +import io.joern.x2cpg.astgen.ParserResult +import io.shiftleft.utils.IOUtils + +import java.nio.file.{Path, Paths} + +object RubyJsonParser { + + def readFile(file: Path): ParserResult = { + val jsonContent = IOUtils.readLinesInFile(file).mkString + val json = ujson.read(jsonContent) + val fullFilePath = json(ParserKeys.FilePath).str + val filePath = Paths.get(fullFilePath) + val relFilePath = json(ParserKeys.RelFilePath).str + val sourceFileContent = IOUtils.readEntireFile(filePath) + ParserResult(relFilePath, filePath.toString, json, sourceFileContent) + } + +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonToNodeCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonToNodeCreator.scala new file mode 100644 index 000000000000..c1ee35acf80d --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyJsonToNodeCreator.scala @@ -0,0 +1,1205 @@ +package io.joern.rubysrc2cpg.parser + +import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.* +import io.joern.rubysrc2cpg.parser.AstType.Send +import io.joern.rubysrc2cpg.parser.RubyJsonHelpers.* +import io.joern.rubysrc2cpg.passes.Defines +import io.joern.rubysrc2cpg.passes.Defines.{NilClass, RubyOperators} +import io.joern.rubysrc2cpg.passes.GlobalTypes.corePrefix +import io.joern.rubysrc2cpg.utils.FreshNameGenerator +import io.joern.x2cpg.frontendspecific.rubysrc2cpg.ImportsPass +import io.joern.x2cpg.frontendspecific.rubysrc2cpg.ImportsPass.ImportCallNames +import org.slf4j.LoggerFactory +import ujson.* + +class RubyJsonToNodeCreator( + variableNameGen: FreshNameGenerator[String] = FreshNameGenerator(id => s""), + procParamGen: FreshNameGenerator[Left[String, Nothing]] = FreshNameGenerator(id => Left(s"")), + fileName: String = "" +) { + + private val logger = LoggerFactory.getLogger(getClass) + private val classNameGen = FreshNameGenerator(id => s"") + + private implicit val implVisit: ujson.Value => RubyExpression = (x: ujson.Value) => visit(x) + + protected def freshClassName(span: TextSpan): SimpleIdentifier = { + SimpleIdentifier(None)(span.spanStart(classNameGen.fresh)) + } + + private def defaultTextSpan(code: String = ""): TextSpan = TextSpan(None, None, None, None, None, code) + + private def defaultResult(span: Option[TextSpan] = None): RubyExpression = + Unknown()(span.getOrElse(defaultTextSpan())) + + private def visit(v: ujson.Value): RubyExpression = { + v match { + case obj: ujson.Obj => visit(obj) + case ujson.Null => StatementList(Nil)(defaultTextSpan()) + case ujson.Str(x) => StaticLiteral(Defines.prefixAsCoreType(Defines.String))(defaultTextSpan(x)) + case x => + logger.warn(s"Unhandled ujson type ${x.getClass}") + defaultResult() + } + } + + /** Main entrypoint of JSON deserialization. + */ + def visitProgram(obj: ujson.Value): StatementList = { + visit(obj.obj) match { + case x: StatementList => x + case x => StatementList(x :: Nil)(x.span) + } + } + + private def visit(obj: ujson.Obj): RubyExpression = { + + def visitAstType(typ: AstType): RubyExpression = { + typ match { + case AstType.Alias => visitAlias(obj) + case AstType.And => visitAnd(obj) + case AstType.AndAssign => visitAndAssign(obj) + case AstType.Arg => visitArg(obj) + case AstType.Args => visitArgs(obj) + case AstType.Array => visitArray(obj) + case AstType.ArrayPattern => visitArrayPattern(obj) + case AstType.ArrayPatternWithTail => visitArrayPatternWithTail(obj) + case AstType.BackRef => visitBackRef(obj) + case AstType.Begin => visitBegin(obj) + case AstType.Block => visitBlock(obj) + case AstType.BlockArg => visitBlockArg(obj) + case AstType.BlockPass => visitBlockPass(obj) + case AstType.BlockWithNumberedParams => visitBlockWithNumberedParams(obj) + case AstType.Break => visitBreak(obj) + case AstType.CaseExpression => visitCaseExpression(obj) + case AstType.CaseMatchStatement => visitCaseMatchStatement(obj) + case AstType.ClassDefinition => visitClassDefinition(obj) + case AstType.ClassVariable => visitClassVariable(obj) + case AstType.ClassVariableAssign => visitSingleAssignment(obj) + case AstType.ConstVariableAssign => visitSingleAssignment(obj) + case AstType.ConditionalSend => visitSend(obj, isConditional = true) + case AstType.Defined => visitDefined(obj) + case AstType.DynamicString => visitDynamicString(obj) + case AstType.DynamicSymbol => visitDynamicSymbol(obj) + case AstType.Ensure => visitEnsure(obj) + case AstType.ExclusiveFlipFlop => visitExclusiveFlipFlop(obj) + case AstType.ExclusiveRange => visitExclusiveRange(obj) + case AstType.ExecutableString => visitExecutableString(obj) + case AstType.False => visitFalse(obj) + case AstType.FindPattern => visitFindPattern(obj) + case AstType.Float => visitFloat(obj) + case AstType.ForStatement => visitForStatement(obj) + case AstType.ForPostStatement => visitForStatement(obj) + case AstType.ForwardArg => visitForwardArg(obj) + case AstType.ForwardArgs => visitForwardArgs(obj) + case AstType.ForwardedArgs => visitForwardedArgs(obj) + case AstType.GlobalVariable => visitGlobalVariable(obj) + case AstType.GlobalVariableAssign => visitGlobalVariableAssign(obj) + case AstType.Hash => visitHash(obj) + case AstType.HashPattern => visitHashPattern(obj) + case AstType.Identifier => visitIdentifier(obj) + case AstType.IfGuard => visitIfGuard(obj) + case AstType.IfStatement => visitIfStatement(obj) + case AstType.InclusiveFlipFlop => visitInclusiveFlipFlop(obj) + case AstType.InclusiveRange => visitInclusiveRange(obj) + case AstType.InPattern => visitInPattern(obj) + case AstType.Int => visitInt(obj) + case AstType.InstanceVariable => visitInstanceVariable(obj) + case AstType.InstanceVariableAssign => visitSingleAssignment(obj) + case AstType.KwArg => visitKwArg(obj) + case AstType.KwBegin => visitKwBegin(obj) + case AstType.KwNilArg => visitKwNilArg(obj) + case AstType.KwOptArg => visitKwOptArg(obj) + case AstType.KwRestArg => visitKwRestArg(obj) + case AstType.KwSplat => visitKwSplat(obj) + case AstType.LocalVariable => visitLocalVariable(obj) + case AstType.LocalVariableAssign => visitSingleAssignment(obj) + case AstType.MatchAlt => visitMatchAlt(obj) + case AstType.MatchAs => visitMatchAs(obj) + case AstType.MatchNilPattern => visitMatchNilPattern(obj) + case AstType.MatchPattern => visitMatchPattern(obj) + case AstType.MatchPatternP => visitMatchPatternP(obj) + case AstType.MatchRest => visitMatchRest(obj) + case AstType.MatchVariable => visitMatchVariable(obj) + case AstType.MatchWithLocalVariableAssign => visitMatchWithLocalVariableAssign(obj) + case AstType.MethodDefinition => visitMethodDefinition(obj) + case AstType.ModuleDefinition => visitModuleDefinition(obj) + case AstType.MultipleAssignment => visitMultipleAssignment(obj) + case AstType.MultipleLeftHandSide => visitMultipleLeftHandSide(obj) + case AstType.Next => visitNext(obj) + case AstType.Nil => visitNil(obj) + case AstType.NthRef => visitNthRef(obj) + case AstType.OperatorAssign => visitOperatorAssign(obj) + case AstType.OptionalArgument => visitOptionalArgument(obj) + case AstType.Or => visitOr(obj) + case AstType.OrAssign => visitOrAssign(obj) + case AstType.Pair => visitPair(obj) + case AstType.PostExpression => visitPostExpression(obj) + case AstType.PreExpression => visitPreExpression(obj) + case AstType.ProcArgument => visitProcArgument(obj) + case AstType.Rational => visitRational(obj) + case AstType.Redo => visitRedo(obj) + case AstType.Retry => visitRetry(obj) + case AstType.Return => visitReturn(obj) + case AstType.RegexExpression => visitRegexExpression(obj) + case AstType.RegexOption => visitRegexOption(obj) + case AstType.ResBody => visitResBody(obj) + case AstType.RestArg => visitRestArg(obj) + case AstType.RescueStatement => visitRescueStatement(obj) + case AstType.ScopedConstant => visitScopedConstant(obj) + case AstType.Self => visitSelf(obj) + case AstType.Send => visitSend(obj) + case AstType.ShadowArg => visitShadowArg(obj) + case AstType.SingletonMethodDefinition => visitSingletonMethodDefinition(obj) + case AstType.SingletonClassDefinition => visitSingletonClassDefinition(obj) + case AstType.Splat => visitSplat(obj) + case AstType.StaticString => visitStaticString(obj) + case AstType.StaticSymbol => visitStaticSymbol(obj) + case AstType.Super => visitSuper(obj) + case AstType.SuperNoArgs => visitSuperNoArgs(obj) + case AstType.TopLevelConstant => visitTopLevelConstant(obj) + case AstType.True => visitTrue(obj) + case AstType.UnDefine => visitUnDefine(obj) + case AstType.UnlessExpression => visitUnlessExpression(obj) + case AstType.UnlessGuard => visitUnlessGuard(obj) + case AstType.UntilExpression => visitUntilExpression(obj) + case AstType.UntilPostExpression => visitUntilPostExpression(obj) + case AstType.WhenStatement => visitWhenStatement(obj) + case AstType.WhileStatement => visitWhileStatement(obj) + case AstType.WhilePostStatement => visitWhileStatement(obj) + case AstType.Yield => visitYield(obj) + } + } + + val astTypeStr = obj(ParserKeys.Type).str + AstType.fromString(astTypeStr) match { + case Some(typ) => visitAstType(typ) + case _ => + logger.warn(s"Unhandled `parser` type '$astTypeStr'") + defaultResult() + } + } + + private def visitAccessModifier(obj: Obj): RubyExpression = { + obj(ParserKeys.Name).str match { + case "public" => PublicModifier()(obj.toTextSpan) + case "private" => PrivateModifier()(obj.toTextSpan) + case "protected" => ProtectedModifier()(obj.toTextSpan) + case modifierName => + logger.warn(s"Unknown modifier type $modifierName") + defaultResult(Option(obj.toTextSpan)) + } + } + + private def visitAlias(obj: Obj): RubyExpression = { + if (AstType.fromString(obj(ParserKeys.Type).str).contains(AstType.Send)) { + obj.visitArray(ParserKeys.Arguments) match { + case name :: alias :: _ => // different order than the normal `alias` kw + AliasStatement(alias.text.stripPrefix(":"), name.text.stripPrefix(":"))(obj.toTextSpan) + case _ => defaultResult(Option(obj.toTextSpan)) + } + } else { + val name = visit(obj(ParserKeys.Name)).text.stripPrefix(":") + val alias = visit(obj(ParserKeys.Alias)).text.stripPrefix(":") + AliasStatement(alias, name)(obj.toTextSpan) + } + + } + + private def visitAnd(obj: Obj): RubyExpression = { + val op = "&&" + val lhs = visit(obj(ParserKeys.Lhs)) + val rhs = visit(obj(ParserKeys.Rhs)) + BinaryExpression(lhs, op, rhs)(obj.toTextSpan) + } + + private def visitAndAssign(obj: Obj): RubyExpression = { + val lhs = visit(obj(ParserKeys.Lhs)) match { + case param: MandatoryParameter => param.toSimpleIdentifier + case x => x + } + val rhs = visit(obj(ParserKeys.Rhs)) + OperatorAssignment(lhs, "&&=", rhs)(obj.toTextSpan) + } + + private def visitArg(obj: Obj): RubyExpression = MandatoryParameter(obj(ParserKeys.Value).str)(obj.toTextSpan) + + private def visitArgs(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitArray(obj: Obj): RubyExpression = { + val children = obj.visitArray(ParserKeys.Children).flatMap { + case x: AssociationList => x.elements + case x => x :: Nil + } + + ArrayLiteral(children)(obj.toTextSpan) + } + + private def visitArrayPattern(obj: Obj): RubyExpression = { + val children = obj.visitArray(ParserKeys.Children) + ArrayPattern(children)(obj.toTextSpan) + } + + private def visitArrayPatternWithTail(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitBackRef(obj: Obj): RubyExpression = SimpleIdentifier()(obj.toTextSpan) + + private def visitBegin(obj: Obj): RubyExpression = { + StatementList(obj.visitArray(ParserKeys.Body))(obj.toTextSpan) + } + + private def visitGroupedParameter(arrayParam: ArrayLiteral): RubyExpression = { + val freshTmpVar = variableNameGen.fresh + val tmpMandatoryParam = MandatoryParameter(freshTmpVar)(arrayParam.span.spanStart(freshTmpVar)) + + val singleAssignments = arrayParam.elements.map { param => + val rhsSplattingNode = SplattingRubyNode(tmpMandatoryParam)(arrayParam.span.spanStart(s"*$freshTmpVar")) + val lhs = param match { + case x: SimpleIdentifier => SimpleIdentifier()(x.span) + case x: ArrayParameter => + SplattingRubyNode(SimpleIdentifier()(arrayParam.span.spanStart(x.span.text.stripPrefix("*"))))( + arrayParam.span.spanStart(x.span.text) + ) + case x: ArrayLiteral => + visitGroupedParameter(x) + case x => + logger.warn( + s"Invalid parameter type in grouped parameter list: ${x.getClass} (code: ${arrayParam.span.text})" + ) + defaultResult(Option(arrayParam.span)) + } + SingleAssignment(lhs, "=", rhsSplattingNode)( + arrayParam.span.spanStart(s"${lhs.span.text} = ${rhsSplattingNode.span.text}") + ) + } + + GroupedParameter( + tmpMandatoryParam.span.text, + tmpMandatoryParam, + GroupedParameterDesugaring(singleAssignments)(arrayParam.span) + )(arrayParam.span) + } + + private def visitBlock(obj: Obj): RubyExpression = { + val parameters = obj(ParserKeys.Arguments).asInstanceOf[ujson.Obj].visitArray(ParserKeys.Children).map { + case x: ArrayLiteral => visitGroupedParameter(x) + case x => x + } + + val assignments = parameters.collect { case x: GroupedParameter => + x.multipleAssignment + } + + val body = obj.visitOption(ParserKeys.Body) match { + case Some(stmt: StatementList) => stmt.copy(stmt.statements ++ assignments)(stmt.span) + case Some(expr) => StatementList(expr +: assignments)(expr.span) + case None => StatementList(Nil)(obj.toTextSpan) + } + + val block = Block(parameters, body)(body.span.spanStart(obj.toTextSpan.text)) + visit(obj(ParserKeys.CallName)) match { + case classNew: ObjectInstantiation if classNew.span.text == "Class.new" => + AnonymousClassDeclaration(freshClassName(obj.toTextSpan), None, block.toStatementList)(obj.toTextSpan) + case objNew: ObjectInstantiation => objNew.withBlock(block) + case lambda: SimpleIdentifier if lambda.text == "lambda" => ProcOrLambdaExpr(block)(obj.toTextSpan) + case ident: SimpleIdentifier if ident.span.text == "loop" => + val trueLiteral = StaticLiteral(Defines.prefixAsCoreType(Defines.TrueClass))(ident.span.spanStart("true")) + DoWhileExpression(trueLiteral, body)(ident.span) + case simpleIdentifier: SimpleIdentifier => + SimpleCall(simpleIdentifier, Nil)(obj.toTextSpan).withBlock(block) + case simpleCall: RubyCall => simpleCall.withBlock(block) + case memberAccess @ MemberAccess(target, op, memberName) => + val memberCall = MemberCall(target, op, memberName, List.empty)(memberAccess.span) + memberCall.withBlock(block) + case x: ProtectedModifier => + SimpleCall(x.toSimpleIdentifier, Nil)(obj.toTextSpan).withBlock(block) + case x => + logger.warn(s"Unexpected call type used for block ${x.getClass}, ignoring block") + x + } + } + + private def visitBlockArg(obj: Obj): RubyExpression = { + val span = obj.toTextSpan + val name = obj(ParserKeys.Value).strOpt.filterNot(_ == "&").getOrElse(procParamGen.fresh.value) + ProcParameter(name)(span) + } + + private def visitBlockPass(obj: Obj): RubyExpression = { + lazy val default = SimpleIdentifier()(obj.toTextSpan.spanStart(procParamGen.current.value)) + obj.visitOption(ParserKeys.Value).getOrElse(default) + } + + private def visitBlockWithNumberedParams(obj: Obj): RubyExpression = SimpleIdentifier()(obj.toTextSpan) + + private def visitBracketAssignmentAsSend(obj: Obj): RubyExpression = { + val lhsBase = visit(obj(ParserKeys.Receiver)) + val args = obj.visitArray(ParserKeys.Arguments) + + val lhs = + IndexAccess(lhsBase, List(args.head))(obj.toTextSpan.spanStart(s"${lhsBase.span.text}[${args.head.span.text}]")) + + val rhs = + if args.size == 2 then args(1) + else SimpleIdentifier()(obj.toTextSpan.spanStart("*")) + + SingleAssignment(lhs, "=", rhs)(obj.toTextSpan) + } + + private def visitBreak(obj: Obj): RubyExpression = BreakExpression()(obj.toTextSpan) + + private def visitCaseExpression(obj: Obj): RubyExpression = { + val expression = obj.visitOption(ParserKeys.CaseExpression) + val whenClauses = obj.visitArray(ParserKeys.WhenClauses) + + val elseClause = obj.visitOption(ParserKeys.ElseClause) match { + case Some(elseClause) => Some(ElseClause(elseClause)(elseClause.span)) + case None => None + } + + CaseExpression(expression, whenClauses, elseClause)(obj.toTextSpan) + } + + private def visitCaseMatchStatement(obj: Obj): RubyExpression = { + val expression = visit(obj(ParserKeys.Statement)) + val inClauses = obj.visitArray(ParserKeys.Bodies) + val elseClause = obj.visitOption(ParserKeys.ElseClause).map(x => ElseClause(x)(x.span)) + + CaseExpression(Some(expression), inClauses, elseClause)(obj.toTextSpan) + } + + private def visitClassDefinition(obj: Obj): RubyExpression = { + val (name, namespaceParts) = visit(obj(ParserKeys.Name)) match { + case memberAccess: MemberAccess => + val memberIdentifier = SimpleIdentifier()(memberAccess.span.spanStart(memberAccess.memberName)) + (memberIdentifier, Option(getParts(memberAccess).dropRight(1))) + case identifier => (identifier, None) + } + val baseClass = obj.visitOption(ParserKeys.SuperClass) + val (body, fields) = createClassBodyAndFields(obj) + val bodyMemberCall = createBodyMemberCall(name.text, obj.toTextSpan) + ClassDeclaration( + name = name, + baseClass = baseClass, + body = body, + fields = fields, + bodyMemberCall = Option(bodyMemberCall), + namespaceParts = namespaceParts + )(obj.toTextSpan) + } + + private def visitClassVariable(obj: Obj): RubyExpression = ClassFieldIdentifier()(obj.toTextSpan) + + private def visitCollectionAliasSend(obj: Obj): RubyExpression = { + // Modify this `obj` to conform to what the AstCreator would expect i.e, Array [1,2,3] would be an Array::[] call + val collectionName = obj(ParserKeys.Name).str + val metaData = obj(ParserKeys.MetaData) + metaData.obj.put(ParserKeys.Code, collectionName) + val receiver = ujson.Obj( + ParserKeys.Type -> ujson.Str(AstType.ScopedConstant.name), + ParserKeys.MetaData -> metaData, + ParserKeys.Base -> ujson.Null, + ParserKeys.Name -> ujson.Str(collectionName) + ) + val arguments = obj(ParserKeys.Arguments).arr.headOption + .flatMap { + case x: ujson.Obj => AstType.fromString(x(ParserKeys.Type).str).map(t => t -> x) + case _ => None + } + .map { + case (AstType.Array, o) => + o.visitArray(ParserKeys.Children).flatMap { + case x: AssociationList => x.elements + case x => x :: Nil + } + case (_, o) => + visit(o) :: Nil + } + .getOrElse(Nil) + + val textSpan = obj.toTextSpan.spanStart(s"$collectionName [${arguments.map(_.span.text).mkString(", ")}]") + + IndexAccess(visit(receiver), arguments)(textSpan) + } + + private def visitDefined(obj: Obj): RubyExpression = { + val name = + SimpleIdentifier(Option(Defines.prefixAsKernelDefined(Defines.Defined)))( + obj.toTextSpan.spanStart(Defines.Defined) + ) + val arguments = obj.visitArray(ParserKeys.Arguments) + SimpleCall(name, arguments)(obj.toTextSpan) + } + + private def visitDynamicString(obj: Obj): RubyExpression = { + val typeFullName = Defines.prefixAsCoreType(Defines.String) + val expressions = obj.visitArray(ParserKeys.Children) + DynamicLiteral(typeFullName, expressions)(obj.toTextSpan) + } + + private def visitDynamicSymbol(obj: Obj): RubyExpression = { + val typeFullName = Defines.prefixAsCoreType(Defines.Symbol) + val expressions = obj.visitArray(ParserKeys.Children) + DynamicLiteral(typeFullName, expressions)(obj.toTextSpan) + } + + private def visitEnsure(obj: Obj): RubyExpression = { + val ensureClause = EnsureClause(visit(obj(ParserKeys.Body)))(obj.toTextSpan) + visit(obj(ParserKeys.Statement)) match { + case rescueExpression: RescueExpression => + rescueExpression.copy( + rescueExpression.body, + rescueExpression.rescueClauses, + rescueExpression.elseClause, + Some(ensureClause) + )(obj.toTextSpan) + case x => + RescueExpression(x, List.empty, Option.empty, Some(ensureClause))(obj.toTextSpan) + } + } + + private def visitExclusiveFlipFlop(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitExclusiveRange(obj: Obj): RubyExpression = { + val start = visit(obj(ParserKeys.Start)) + val end = visit(obj(ParserKeys.End)) + val op = RangeOperator(true)(obj.toTextSpan.spanStart("...")) + RangeExpression(start, end, op)(obj.toTextSpan) + } + + private def visitExecutableString(obj: Obj): RubyExpression = { + val operatorName = RubyOperators.backticks + val callName = + SimpleIdentifier(Option(Defines.prefixAsKernelDefined(operatorName)))(obj.toTextSpan.spanStart(operatorName)) + val arguments = obj.visitArray(ParserKeys.Arguments) + SimpleCall(callName, arguments)(obj.toTextSpan) + } + + private def visitFalse(obj: Obj): RubyExpression = + StaticLiteral(Defines.prefixAsCoreType(Defines.FalseClass))(obj.toTextSpan) + + private def visitFieldDeclaration(obj: Obj): RubyExpression = { + val arguments = obj.visitArray(ParserKeys.Arguments) + val accessType = obj(ParserKeys.Name).str + FieldsDeclaration(arguments, accessType)(obj.toTextSpan) + } + + private def visitFindPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitFieldAssignmentSend(obj: Obj, fieldName: String): RubyExpression = { + val span = obj.toTextSpan + val receiver = visit(obj(ParserKeys.Receiver)) + val memberAccess = MemberAccess(receiver, ".", fieldName)(receiver.span.spanStart(s"${receiver.text}.@$fieldName")) + val argument = obj + .visitArray(ParserKeys.Arguments) + .headOption + .getOrElse(StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(span.spanStart("nil"))) + SingleAssignment(memberAccess, "=", argument)(span) + } + + private def visitFloat(obj: Obj): RubyExpression = + StaticLiteral(Defines.prefixAsCoreType(Defines.Float))(obj.toTextSpan) + + private def visitForStatement(obj: Obj): RubyExpression = { + val forVariable = visit(obj(ParserKeys.Variable)) + val iterableVariable = visit(obj(ParserKeys.Collection)) + val doBlock = visit(obj(ParserKeys.Body)) match { + case stmtList: StatementList => stmtList + case other => StatementList(List(other))(other.span) + } + + ForExpression(forVariable, iterableVariable, doBlock)(obj.toTextSpan) + } + + private def visitForwardArg(obj: Obj): RubyExpression = { + logger.warn("Forward arg unhandled") + defaultResult(Option(obj.toTextSpan)) + } + + // Note: Forward args should probably be handled more explicitly, but this should preserve flows if the same + // identifier is used in latter forwarding + private def visitForwardArgs(obj: Obj): RubyExpression = MandatoryParameter("...")(obj.toTextSpan) + + private def visitForwardedArgs(obj: Obj): RubyExpression = SimpleIdentifier()(obj.toTextSpan) + + private def visitGlobalVariable(obj: Obj): RubyExpression = { + val span = obj.toTextSpan + val name = obj(ParserKeys.Value).str + val selfBase = SelfIdentifier()(span.spanStart("self")) + MemberAccess(selfBase, ".", name)(span) + } + + private def visitGlobalVariableAssign(obj: Obj): RubyExpression = { + val span = obj.toTextSpan + + val selfBase = SelfIdentifier()(span.spanStart("self")) + val lhsName = obj(ParserKeys.Lhs).str + val lhs = MemberAccess(selfBase, ".", lhsName)(span.spanStart(s"${selfBase.span.text}.$lhsName")) + + val rhs = visit(obj(ParserKeys.Rhs)) + val op = "=" + + SingleAssignment(lhs, op, rhs)(obj.toTextSpan) + } + + private def visitHash(obj: Obj): RubyExpression = { + val isHashLiteral = obj.toTextSpan.text.stripMargin.startsWith("{") + + obj.visitArray(ParserKeys.Children) match { + case (assoc: Association) :: Nil => + if isHashLiteral then HashLiteral(List(assoc))(obj.toTextSpan) + else assoc // 2 => 1 is interpreted as {2: 1}, so we lower this for now + case children => + if isHashLiteral then HashLiteral(children)(obj.toTextSpan) + else AssociationList(children)(obj.toTextSpan) + } + } + + private def visitHashPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitIdentifier(obj: Obj): RubyExpression = SimpleIdentifier()(obj.toTextSpan) + + private def visitIfGuard(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitIfStatement(obj: Obj): RubyExpression = { + val condition = visit(obj(ParserKeys.Condition)) + + val elseClause = obj.visitOption(ParserKeys.ElseBranch).map { + case x: IfExpression => x + case x => ElseClause(StatementList(List(x))(x.span))(x.span) + } + + obj.visitOption(ParserKeys.ThenBranch) match { + case Some(thenBranch) => + IfExpression(condition, thenBranch, elsifClauses = List.empty, elseClause)(obj.toTextSpan) + case None => + val nilBlock = ReturnExpression( + List(StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(obj.toTextSpan.spanStart("nil"))) + )(obj.toTextSpan.spanStart("return nil")) + IfExpression(condition, nilBlock, elsifClauses = List.empty, elseClause)(obj.toTextSpan) + } + } + + private def visitInclude(obj: Obj): RubyExpression = { + val callName = obj(ParserKeys.Name).str + val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName)) + val argument = obj.visitArray(ParserKeys.Arguments).head + + IncludeCall(target, argument)(obj.toTextSpan) + } + + private def visitInclusiveFlipFlop(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitInclusiveRange(obj: Obj): RubyExpression = { + val start = obj.visitOption(ParserKeys.Start) match { + case Some(expr) => expr + case None => infinityLowerBound(obj) + } + val end = obj.visitOption(ParserKeys.End) match { + case Some(expr) => expr + case None => infinityUpperBound(obj) + } + val op = RangeOperator(false)(obj.toTextSpan.spanStart("..")) + RangeExpression(start, end, op)(obj.toTextSpan) + } + + private def visitIndexAccessAsSend(obj: Obj): RubyExpression = { + val target = visit(obj(ParserKeys.Receiver)) + val indices = obj.visitArray(ParserKeys.Arguments) + val isRegexMatch = indices.headOption.exists { + case x: StaticLiteral => x.typeFullName == Defines.prefixAsCoreType(Defines.Regexp) + case _ => false + } + if (isRegexMatch) { + // For regex match that looks like "hello"[/h(el)lo/] + val newProps = obj.value + newProps.put(ParserKeys.Name, RubyOperators.regexpMatch) + visitBinaryExpression(obj.copy(value = newProps)) + } else { + IndexAccess(target, indices)(obj.toTextSpan) + } + } + + private def visitInPattern(obj: Obj): RubyExpression = { + val patternType = visit(obj(ParserKeys.Pattern)) + val patternBody = visit(obj(ParserKeys.Body)) + + InClause(patternType, patternBody)(obj.toTextSpan) + } + + private def visitInt(obj: Obj): RubyExpression = { + val typeFullName = Defines.prefixAsCoreType(Defines.Integer) + StaticLiteral(typeFullName)(obj.toTextSpan) + } + + private def visitInstanceVariable(obj: Obj): RubyExpression = InstanceFieldIdentifier()(obj.toTextSpan) + + private def visitKwArg(obj: Obj): RubyExpression = { + val name = obj(ParserKeys.Key).str + val default = obj + .visitOption(ParserKeys.Value) + .getOrElse(StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(obj.toTextSpan.spanStart("nil"))) + OptionalParameter(name, default)(obj.toTextSpan) + } + + private def visitKwBegin(obj: Obj): RubyExpression = { + val stmts = obj(ParserKeys.Body) match { + case o: Obj => visit(o) :: Nil + case _: Arr => obj.visitArray(ParserKeys.Body) + case _ => + val span = obj.toTextSpan + logger.warn(s"Unhandled JSON body type for `KwBegin`: ${span.text}") + defaultResult(Option(span)) :: Nil + } + StatementList(stmts)(obj.toTextSpan) + } + + private def visitKwNilArg(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitKwOptArg(obj: Obj): RubyExpression = visitKwArg(obj) + + private def visitKwRestArg(obj: Obj): RubyExpression = { + val name = if obj.contains(ParserKeys.Value) then obj(ParserKeys.Value).str else obj.toTextSpan.text + HashParameter(name)(obj.toTextSpan) + } + + private def visitKwSplat(obj: Obj): RubyExpression = { + val values = visit(obj(ParserKeys.Value)) match { + case x: StatementList => x.statements.head + case x => x + } + SplattingRubyNode(values)(obj.toTextSpan) + } + + private def visitLocalVariable(obj: Obj): RubyExpression = SimpleIdentifier()(obj.toTextSpan) + + private def visitMatchAlt(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchAs(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchNilPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchPatternP(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchRest(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchVariable(obj: Obj): RubyExpression = MatchVariable()(obj.toTextSpan) + + private def visitMatchWithLocalVariableAssign(obj: Obj): RubyExpression = { + val lhs = visit(obj(ParserKeys.Lhs)) + val rhs = visit(obj(ParserKeys.Rhs)) + MemberCall(lhs, ".", RubyOperators.regexpMatch, rhs :: Nil)(obj.toTextSpan) + } + + private def visitMethodAccessModifier(obj: Obj): RubyExpression = { + val body = obj.visitArray(ParserKeys.Arguments) match { + case head :: Nil => head + case xs => xs.head + } + + obj(ParserKeys.Name).str match { + case "public_class_method" => + PublicMethodModifier(body)(obj.toTextSpan) + case "private_class_method" => + PrivateMethodModifier(body)(obj.toTextSpan) + case modifierName => + logger.warn(s"Unknown modifier type $modifierName") + defaultResult(Option(obj.toTextSpan)) + } + } + + private def visitMethodDefinition(obj: Obj): RubyExpression = { + val name = obj(ParserKeys.Name).str + val parameters = visitMethodParameters(obj(ParserKeys.Arguments).asInstanceOf[ujson.Obj]) + val body = obj + .visitOption(ParserKeys.Body) + .map { + case x: StatementList => x + case x => StatementList(List(x))(x.span) + } + .getOrElse(StatementList(Nil)(obj.toTextSpan.spanStart(""))) + MethodDeclaration(name, parameters, body)(obj.toTextSpan) + } + + private def visitModuleDefinition(obj: Obj): RubyExpression = { + val (name, namespaceParts) = visit(obj(ParserKeys.Name)) match { + case memberAccess: MemberAccess => + val memberIdentifier = SimpleIdentifier()(memberAccess.span.spanStart(memberAccess.memberName)) + (memberIdentifier, Option(getParts(memberAccess).dropRight(1))) + case identifier => (identifier, None) + } + val (body, fields) = createClassBodyAndFields(obj) + val bodyMemberCall = createBodyMemberCall(name.text, obj.toTextSpan) + ModuleDeclaration( + name = name, + body = body, + fields = fields, + bodyMemberCall = Option(bodyMemberCall), + namespaceParts = namespaceParts + )(obj.toTextSpan) + } + + private def visitMultipleAssignment(obj: Obj): RubyExpression = { + val lhs = visit(obj(ParserKeys.Lhs)) match { + case _ @ArrayLiteral(elements) => elements + case expr => expr :: Nil + } + val rhs = visit(obj(ParserKeys.Rhs)) match { + case _ @ArrayLiteral(elements) => elements + case expr => expr :: Nil + } + lowerMultipleAssignment( + obj, + lhs, + rhs, + () => defaultResult(), + () => StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(obj.toTextSpan) + ) + } + + private def visitMultipleLeftHandSide(obj: Obj): RubyExpression = { + val arr = visitArray(obj).asInstanceOf[ArrayLiteral] + arr.copy(elements = arr.elements.map { + case param: MandatoryParameter => param.toSimpleIdentifier + case expr => expr + })(arr.span) + } + + private def visitNext(obj: Obj): RubyExpression = NextExpression()(obj.toTextSpan) + + private def visitNil(obj: Obj): RubyExpression = + StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(obj.toTextSpan) + + private def visitNthRef(obj: Obj): RubyExpression = { + // We represent $1 as $[1] in order to track these arbitrary numeric accesses in a way the data-flow engine + // understands + val span = obj.toTextSpan + val name = obj(ParserKeys.Value).num.toInt + val selfBase = SelfIdentifier()(span.spanStart("self")) + val amperMemberAccess = MemberAccess(selfBase, ".", "$")(span) + val indexPos = StaticLiteral(Defines.prefixAsCoreType(Defines.Integer))(obj.toTextSpan.spanStart(name.toString)) + IndexAccess(amperMemberAccess, indexPos :: Nil)(obj.toTextSpan.spanStart(s"$$[$name]")) + } + + private def visitObjectInstantiation(obj: Obj): RubyExpression = { + // The receiver is the target with the JSON parser + val receiver = visit(obj(ParserKeys.Receiver)) + val arguments = obj.visitArray(ParserKeys.Arguments) + SimpleObjectInstantiation(receiver, arguments)(obj.toTextSpan) + } + + private def visitOperatorAssign(obj: Obj): RubyExpression = { + val lhs = visit(obj(ParserKeys.Lhs)) match { + case param: MandatoryParameter => param.toSimpleIdentifier + case x => x + } + val op = s"${obj(ParserKeys.Op).str}=" + val rhs = visit(obj(ParserKeys.Rhs)) + SingleAssignment(lhs, op, rhs)(obj.toTextSpan) + } + + private def visitOptionalArgument(obj: Obj): RubyExpression = { + val name = obj(ParserKeys.Key).str + val default = visit(obj(ParserKeys.Value)) + OptionalParameter(name, default)(obj.toTextSpan) + } + + private def visitOr(obj: Obj): RubyExpression = { + val op = "||" + val lhs = visit(obj(ParserKeys.Lhs)) + val rhs = visit(obj(ParserKeys.Rhs)) + BinaryExpression(lhs, op, rhs)(obj.toTextSpan) + } + + private def visitOrAssign(obj: Obj): RubyExpression = { + val lhs = visit(obj(ParserKeys.Lhs)) match { + case param: MandatoryParameter => param.toSimpleIdentifier + case x => x + } + val rhs = visit(obj(ParserKeys.Rhs)) + OperatorAssignment(lhs, "||=", rhs)(obj.toTextSpan) + } + + private def visitPair(obj: Obj): RubyExpression = { + val key = visit(obj(ParserKeys.Key)) + val value = visit(obj(ParserKeys.Value)) + Association(key, value)(obj.toTextSpan) + } + + private def visitMethodParameters(paramsNode: Obj): List[RubyExpression] = { + AstType.fromString(paramsNode(ParserKeys.Type).str) match { + case Some(AstType.Args) => paramsNode.visitArray(ParserKeys.Children) + case Some(AstType.ForwardArgs) => visit(paramsNode) :: Nil + case Some(x) => + logger.warn(s"Not explicitly handled parameter type '$x', no special handling applied") + visit(paramsNode) :: Nil + case _ => + logger.error(s"Unknown JSON type used as method parameter ${paramsNode(ParserKeys.Type).str}") + defaultResult(Option(paramsNode.toTextSpan)) :: Nil + } + } + + private def visitPostExpression(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitPreExpression(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitProcArgument(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitRaise(obj: Obj): RubyExpression = { + val callName = obj(ParserKeys.Name).str + val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName)) + + obj.visitArray(ParserKeys.Arguments) match { + case Nil => RaiseCall(target, List.empty)(obj.toTextSpan) + case (argument: StaticLiteral) :: Nil => + val simpleErrorId = + SimpleIdentifier(Option(Defines.prefixAsCoreType("StandardError")))(argument.span.spanStart("StandardError")) + val implicitSimpleErrInst = SimpleObjectInstantiation(simpleErrorId, argument :: Nil)( + argument.span.spanStart(s"StandardError.new(${argument.text})") + ) + RaiseCall(target, implicitSimpleErrInst :: Nil)(obj.toTextSpan) + case argument :: Nil => + RaiseCall(target, List(argument))(obj.toTextSpan) + case arguments => + RaiseCall(target, arguments)(obj.toTextSpan) + } + + } + + private def visitRational(obj: Obj): RubyExpression = + StaticLiteral(Defines.prefixAsCoreType(Defines.Rational))(obj.toTextSpan) + + private def visitRedo(obj: Obj): RubyExpression = { + val callTarget = SimpleIdentifier()(obj.toTextSpan.spanStart("redo")) + SimpleCall(callTarget, Nil)(obj.toTextSpan) + } + + private def visitRetry(obj: Obj): RubyExpression = { + val callTarget = SimpleIdentifier()(obj.toTextSpan.spanStart("retry")) + SimpleCall(callTarget, Nil)(obj.toTextSpan) + } + + private def visitReturn(obj: Obj): RubyExpression = { + if (obj.contains(ParserKeys.Values)) { + val returnExpressions = obj.visitArray(ParserKeys.Values) + ReturnExpression(returnExpressions)(obj.toTextSpan) + } else if (obj.contains(ParserKeys.Value)) { + ReturnExpression(visit(obj(ParserKeys.Value)) :: Nil)(obj.toTextSpan) + } else { + ReturnExpression(List.empty)(obj.toTextSpan) + } + } + + private def visitRegexExpression(obj: Obj): RubyExpression = { + obj.visitOption(ParserKeys.Value) match { + case Some(_ @StatementList(stmts)) => + DynamicLiteral(Defines.prefixAsCoreType(Defines.Regexp), stmts)(obj.toTextSpan) + case _ => StaticLiteral(Defines.prefixAsCoreType(Defines.Regexp))(obj.toTextSpan) + } + } + + private def visitRegexOption(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitResBody(obj: Obj): RubyExpression = { + val exceptionClassList = obj.visitOption(ParserKeys.ExecList) + val variables = obj.visitOption(ParserKeys.ExecVar) + val body = obj.visitOption(ParserKeys.Body) match { + case Some(stmt: StatementList) => stmt + case Some(expr) => StatementList(expr :: Nil)(expr.span) + case None => StatementList(Nil)(obj.toTextSpan) + } + RescueClause(exceptionClassList, variables, body)(obj.toTextSpan) + } + + private def visitRestArg(obj: Obj): RubyExpression = { + obj(ParserKeys.Value) match { + case ujson.Null => ArrayParameter("*")(obj.toTextSpan) + case ujson.Str(name) => ArrayParameter(name)(obj.toTextSpan) + case x => + logger.warn(s"Unhandled `restarg` JSON type '$x'") + defaultResult(Option(obj.toTextSpan)) + } + } + + private def visitRescueStatement(obj: Obj): RubyExpression = { + val stmt = visit(obj(ParserKeys.Statement)) + val rescueClauses = obj.visitArray(ParserKeys.Bodies).asInstanceOf[List[RescueClause]] + val elseClause = obj.visitOption(ParserKeys.ElseClause) match { + case Some(body) => Option(ElseClause(body)(body.span)) + case None => Option.empty + } + + RescueExpression(stmt, rescueClauses, elseClause, Option.empty)(obj.toTextSpan) + } + + private def visitRequireLike(obj: Obj): RubyExpression = { + val callName = obj(ParserKeys.Name).str + val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName)) + val argument = obj + .visitArray(ParserKeys.Arguments) + .headOption + .getOrElse(StaticLiteral(Defines.prefixAsCoreType(Defines.NilClass))(obj.toTextSpan.spanStart("nil"))) + val isRelative = callName == "require_relative" || callName == "require_all" + val isWildcard = callName == "require_all" + RequireCall(target, argument, isRelative, isWildcard)(obj.toTextSpan) + } + + private def visitScopedConstant(obj: Obj): RubyExpression = { + val identifier = obj(ParserKeys.Name).str + if (obj.contains(ParserKeys.Base)) { + val target = visit(obj(ParserKeys.Base)) + val op = if obj.toTextSpan.text.contains("::") then "::" else "." + MemberAccess(target, op, identifier)(obj.toTextSpan) + } else { + SimpleIdentifier()(obj.toTextSpan) + } + } + + private def visitSelf(obj: Obj): RubyExpression = SelfIdentifier()(obj.toTextSpan) + + private def visitBinaryExpression(obj: Obj): RubyExpression = { + val callName = obj(ParserKeys.Name).str + val lhs = visit(obj(ParserKeys.Receiver)) + val rhs = obj.visitArray(ParserKeys.Arguments).head + // Transform `match` to `=~` so that it is lowered later + val op = if RubyOperators.regexMethods.contains(callName) then RubyOperators.regexpMatch else callName + BinaryExpression(lhs, op, rhs)(obj.toTextSpan) + } + + private def visitSend(obj: Obj, isConditional: Boolean = false): RubyExpression = { + val callName = obj(ParserKeys.Name).str + val hasReceiver = obj.contains(ParserKeys.Receiver) + callName match { + case "new" => visitObjectInstantiation(obj) + case "Array" | "Hash" => visitCollectionAliasSend(obj) + case "[]" => visitIndexAccessAsSend(obj) + case "[]=" => visitBracketAssignmentAsSend(obj) + case "raise" => visitRaise(obj) + case "include" => visitInclude(obj) + case "alias_method" => visitAlias(obj) + case "attr_reader" | "attr_writer" | "attr_accessor" => visitFieldDeclaration(obj) + case "private" | "public" | "protected" => visitAccessModifier(obj) + case "private_class_method" | "public_class_method" => visitMethodAccessModifier(obj) + case requireLike if ImportCallNames.contains(requireLike) && !hasReceiver => visitRequireLike(obj) + case x + if BinaryOperators.isBinaryOperatorName(callName) + || RubyOperators.regexMethods.contains(x) => // assert `match`, `sub`, or `gsub` is always for regex + visitBinaryExpression(obj) + case _ if UnaryOperators.isUnaryOperatorName(callName) => + UnaryExpression(callName, visit(obj(ParserKeys.Receiver)))(obj.toTextSpan) + case s"$name=" if hasReceiver => visitFieldAssignmentSend(obj, name) + case _ => + val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName)) + val argumentArr = obj.visitArray(ParserKeys.Arguments) + val arguments = argumentArr.flatMap { + case hashLiteral: HashLiteral => hashLiteral.elements // a hash is likely named arguments + case assocList: AssociationList => assocList.elements // same as above + case x => x :: Nil + } + val objSpan = obj.toTextSpan + val hasArguments = arguments.nonEmpty + val usesParenthesis = objSpan.text.endsWith(")") + if (obj.contains(ParserKeys.Receiver)) { + val base = visit(obj(ParserKeys.Receiver)) + val isMemberCall = usesParenthesis || callName == "<<" || hasArguments + val op = { + val dot = if objSpan.text.stripPrefix(base.text).startsWith("::") then "::" else "." + if isConditional then s"&$dot" else dot + } + if isMemberCall then MemberCall(base, op, callName, arguments)(obj.toTextSpan) + else MemberAccess(base, op, callName)(obj.toTextSpan) + } else if (hasArguments || usesParenthesis) { + SimpleCall(target, arguments)(obj.toTextSpan) + } else { + // The following allows the AstCreator to approximate when an identifier could be a call or not - puts less + // strain on data-flow tracking for externally inherited accessor calls such as `params` in RubyOnRails + SimpleIdentifier()(obj.toTextSpan.spanStart(callName)) + } + } + } + + private def visitShadowArg(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitSingletonMethodDefinition(obj: Obj): RubyExpression = { + val base = visit(obj(ParserKeys.Base)) + val name = obj(ParserKeys.Name).str + val parameters = visitMethodParameters(obj(ParserKeys.Arguments).asInstanceOf[ujson.Obj]) + val body = + obj.visitOption(ParserKeys.Body).getOrElse(StatementList(Nil)(obj.toTextSpan.spanStart(""))) match { + case stmtList: StatementList => stmtList + case expr => StatementList(expr :: Nil)(expr.span) + } + SingletonMethodDeclaration(base, name, parameters, body)(obj.toTextSpan) + } + + private def visitSingletonClassDefinition(obj: Obj): RubyExpression = { + val name = visit(obj(ParserKeys.Name)) + val baseClass = obj.visitOption(ParserKeys.SuperClass) + val body = obj.visitOption(ParserKeys.Body).getOrElse(StatementList(Nil)(obj.toTextSpan.spanStart(""))) + + obj.visitOption(ParserKeys.Def) match { + case Some(body) => + name match { + case _: SelfIdentifier => + val bodyList = body match { + case stmtList: StatementList => stmtList + case expr => StatementList(expr :: Nil)(expr.span) + } + + val base = baseClass match { + case Some(baseClass) => baseClass + case None => SelfIdentifier()(obj.toTextSpan.spanStart("self")) + } + + SingletonClassDeclaration(freshClassName(obj.toTextSpan), Some(base), bodyList)(obj.toTextSpan) + case _ => + def mapDefBody(defBody: RubyExpression): RubyExpression = defBody match { + case method @ MethodDeclaration(methodName, parameters, body) => + val memberAccess = + MemberAccess(name, ".", methodName)(method.span.spanStart(s"${name.span.text}.${methodName}")) + val singletonBlockMethod = + SingletonObjectMethodDeclaration(methodName, parameters, body, name)(method.span) + SingleAssignment(memberAccess, "=", singletonBlockMethod)( + method.span.spanStart(s"${memberAccess.span.text} = ${method.span.text}") + ) + case expr => expr + } + + val stmts = body match { + case _ @StatementList(stmts) => stmts.map(mapDefBody) + case expr => mapDefBody(expr) :: Nil + } + SingletonStatementList(stmts)(obj.toTextSpan) + } + + case None => + val anonName = freshClassName(obj.toTextSpan) + SingletonClassDeclaration(name = anonName, baseClass = baseClass, body = body)(obj.toTextSpan) + } + } + + private def visitSingleAssignment(obj: Obj): RubyExpression = { + val lhsSpan = obj.toTextSpan.spanStart(obj(ParserKeys.Lhs).str) + val lhs = obj(ParserKeys.Lhs).str match { + case s"@@$_" => ClassFieldIdentifier()(lhsSpan) + case s"@$_" => InstanceFieldIdentifier()(lhsSpan) + case _ => SimpleIdentifier()(lhsSpan) + } + obj.visitOption(ParserKeys.Rhs) match { + case Some(rhs) => + SingleAssignment(lhs, "=", rhs)(obj.toTextSpan) + case None => + if (AstType.fromString(obj(ParserKeys.Type).str) == AstType.LocalVariableAssign) { + // `lvasgn` is used in exec_var for rescueExpr, which only has LHS + MandatoryParameter(lhs.span.text)(lhs.span) + } else { + lhs + } + } + } + + private def visitSplat(obj: Obj): RubyExpression = { + obj.visitOption(ParserKeys.Value) match { + case Some(x) => SplattingRubyNode(x)(obj.toTextSpan) + case None => + val emptyStar = SimpleIdentifier()(obj.toTextSpan.spanStart("_")) + SplattingRubyNode(emptyStar)(obj.toTextSpan) + } + } + + private def visitStaticString(obj: Obj): RubyExpression = { + val typeFullName = Defines.prefixAsCoreType(Defines.String) + val originalSpan = obj.toTextSpan + val value = obj(ParserKeys.Value).str + // In general, we want the quotations, unless it is a HEREDOC string, then we'd prefer the value + val span = if !originalSpan.text.contains(value) then originalSpan.spanStart(value) else originalSpan + StaticLiteral(typeFullName)(span) + } + + private def visitStaticSymbol(obj: Obj): RubyExpression = { + val typeFullName = Defines.prefixAsCoreType(Defines.Symbol) + val objTextSpan = obj.toTextSpan + + if objTextSpan.text.startsWith(":") then StaticLiteral(typeFullName)(obj.toTextSpan) + else StaticLiteral(typeFullName)(objTextSpan.spanStart(s":${objTextSpan.text}")) + } + + private def visitSuper(obj: Obj): RubyExpression = { + val name = + SimpleIdentifier(Option(Defines.prefixAsKernelDefined(Defines.Super)))(obj.toTextSpan.spanStart(Defines.Super)) + val arguments = obj.visitArray(ParserKeys.Arguments) + SimpleCall(name, arguments)(obj.toTextSpan) + } + + private def visitSuperNoArgs(obj: Obj): RubyExpression = { + val name = + SimpleIdentifier(Option(Defines.prefixAsKernelDefined(Defines.Super)))(obj.toTextSpan.spanStart(Defines.Super)) + SimpleCall(name, Nil)(obj.toTextSpan) + } + + private def visitTopLevelConstant(obj: Obj): RubyExpression = { + if (obj.contains(ParserKeys.Name)) { + val identifier = obj(ParserKeys.Name).str + SimpleIdentifier()(obj.toTextSpan.spanStart(identifier)) + } else { + SelfIdentifier()(obj.toTextSpan.spanStart("self")) + } + } + + private def visitTrue(obj: Obj): RubyExpression = + StaticLiteral(Defines.prefixAsCoreType(Defines.TrueClass))(obj.toTextSpan) + + private def visitUnDefine(obj: Obj): RubyExpression = { + defaultResult(Option(obj.toTextSpan)) + } + + private def visitUnlessExpression(obj: Obj): RubyExpression = { + defaultResult(Option(obj.toTextSpan)) + } + + private def visitUnlessGuard(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitUntilExpression(obj: Obj): RubyExpression = { + val condition = visit(obj(ParserKeys.Condition)) + val body = visit(obj(ParserKeys.Body)) + + UntilExpression(condition, body)(obj.toTextSpan) + } + + private def visitUntilPostExpression(obj: Obj): RubyExpression = { + val condition = visit(obj(ParserKeys.Condition)) + val body = visit(obj(ParserKeys.Body)) + + DoWhileExpression(condition, body)(obj.toTextSpan) + } + + private def visitWhenStatement(obj: Obj): RubyExpression = { + val (matchCondition, matchSplatCondition) = obj.visitArray(ParserKeys.Conditions).partition { + case x: SplattingRubyNode => false + case x => true + } + + val thenClause = visit(obj(ParserKeys.ThenBranch)) + + WhenClause(matchCondition, matchSplatCondition.headOption, thenClause)(obj.toTextSpan) + } + + private def visitWhileStatement(obj: Obj): RubyExpression = { + val condition = visit(obj(ParserKeys.Condition)) match { + case x: StatementList => x.statements.head + case x => x + } + + val body = visit(obj(ParserKeys.Body)) + + WhileExpression(condition, body)(obj.toTextSpan) + } + + private def visitYield(obj: Obj): RubyExpression = { + val arguments = obj.visitArray(ParserKeys.Arguments) + YieldExpr(arguments)(obj.toTextSpan) + } + +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyLexerBase.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyLexerBase.scala deleted file mode 100644 index d3aea681c180..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyLexerBase.scala +++ /dev/null @@ -1,49 +0,0 @@ -package io.joern.rubysrc2cpg.parser - -import io.joern.rubysrc2cpg.parser.RubyLexer.* -import org.antlr.v4.runtime.Recognizer.EOF -import org.antlr.v4.runtime.{CharStream, Lexer, Token} - -/** Aggregates auxiliary features to RubyLexer in a single place. */ -abstract class RubyLexerBase(input: CharStream) - extends Lexer(input) - with RegexLiteralHandling - with InterpolationHandling - with QuotedLiteralHandling - with HereDocHandling { - - /** The previously (non-WS) emitted token (in DEFAULT_CHANNEL.) */ - protected var previousNonWsToken: Option[Token] = None - - /** The previously emitted token (in DEFAULT_CHANNEL.) */ - protected var previousToken: Option[Token] = None - - // Same original behaviour, just updating `previous{NonWs}Token`. - override def nextToken: Token = { - val token: Token = super.nextToken - if (token.getChannel == Token.DEFAULT_CHANNEL && token.getType != WS) { - previousNonWsToken = Some(token) - } - previousToken = Some(token) - token - } - - def previousNonWsTokenTypeOrEOF(): Int = { - previousNonWsToken.map(_.getType).getOrElse(EOF) - } - - def previousTokenTypeOrEOF(): Int = { - previousToken.map(_.getType).getOrElse(EOF) - } - - def isNumericTokenType(tokenType: Int): Boolean = { - val numericTokenTypes = Set( - DECIMAL_INTEGER_LITERAL, - OCTAL_INTEGER_LITERAL, - HEXADECIMAL_INTEGER_LITERAL, - FLOAT_LITERAL_WITHOUT_EXPONENT, - FLOAT_LITERAL_WITH_EXPONENT - ) - numericTokenTypes.contains(tokenType) - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyLexerPostProcessor.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyLexerPostProcessor.scala deleted file mode 100644 index bdf5e6727cd2..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyLexerPostProcessor.scala +++ /dev/null @@ -1,74 +0,0 @@ -package io.joern.rubysrc2cpg.parser - -import io.joern.rubysrc2cpg.parser.RubyLexer.* -import org.antlr.v4.runtime.Recognizer.EOF -import org.antlr.v4.runtime.misc.Pair -import org.antlr.v4.runtime.{CommonToken, ListTokenSource, Token, TokenSource} - -import scala.:: -import scala.jdk.CollectionConverters.* - -/** Simplifies the token stream obtained from `RubyLexer`. - */ -object RubyLexerPostProcessor { - - def apply(tokenSource: TokenSource): ListTokenSource = { - var tokens = tokenSource.toSeq - - tokens = tokens.mergeConsecutive(NON_EXPANDED_LITERAL_CHARACTER, NON_EXPANDED_LITERAL_CHARACTER_SEQUENCE) - tokens = tokens.mergeConsecutive(EXPANDED_LITERAL_CHARACTER, EXPANDED_LITERAL_CHARACTER_SEQUENCE) - tokens = tokens.filterNot(_.is(WS)) - - new ListTokenSource(tokens.asJava) - } -} - -private implicit class TokenSourceExt(val tokenSource: TokenSource) { - - def toSeq: Seq[Token] = Seq.unfold(tokenSource) { tkSrc => - tkSrc.nextToken() match - case tk if tk.is(EOF) => None - case tk => Some((tk, tkSrc)) - } -} - -private implicit class SeqExt[A](val elems: Seq[A]) { - - /** An order-preserving `groupBy` implemented on top of `Seq`. Each sub-sequence ("chain") contains 1+ elements. If a - * chain contains 2+ elements, then all its elements satisfy `p`. Flattening returns the original sequence. - */ - def chains(p: A => Boolean): Seq[Seq[A]] = elems.foldRight(Nil: Seq[Seq[A]]) { (h, t) => - t match - case chain :: chains if chain.exists(p) && p(h) => (h +: chain) +: chains - case _ => Seq(h) +: t - } - - /** Collapses, according to a merging operation `m`, all chains that verify `p`. - */ - def mergeChains(p: A => Boolean, m: Seq[A] => A): Seq[A] = { - elems.chains(p).flatMap(chain => if (chain.exists(p)) Seq(m(chain)) else chain) - } - -} - -private implicit class TokenSeqExt(val tokens: Seq[Token]) { - - def mergeAs(tokenType: Int): Token = { - val startIndex = tokens.head.getStartIndex - val stopIndex = tokens.last.getStopIndex - val tokenSource = tokens.head.getTokenSource - val inputStream = tokens.head.getInputStream - val channel = tokens.head.getChannel - new CommonToken(new Pair(tokenSource, inputStream), tokenType, channel, startIndex, stopIndex) - } - - def mergeConsecutive(oldTokenType: Int, newTokenType: Int): Seq[Token] = { - tokens.mergeChains(_.is(oldTokenType), _.mergeAs(newTokenType)) - } -} - -private implicit class TokenExt(val token: Token) { - - def is(tokenType: Int): Boolean = token.getType == tokenType - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala deleted file mode 100644 index 4f5ec2560ea3..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala +++ /dev/null @@ -1,1223 +0,0 @@ -package io.joern.rubysrc2cpg.parser - -import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.* -import io.joern.rubysrc2cpg.parser.AntlrContextHelpers.* -import io.joern.rubysrc2cpg.parser.RubyParser.{CommandWithDoBlockContext, ConstantVariableReferenceContext} -import io.joern.rubysrc2cpg.passes.Defines -import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType -import io.joern.rubysrc2cpg.utils.FreshNameGenerator -import io.joern.x2cpg.Defines as XDefines -import org.antlr.v4.runtime.tree.{ParseTree, RuleNode} -import org.slf4j.LoggerFactory - -import scala.jdk.CollectionConverters.* - -/** Converts an ANTLR Ruby Parse Tree into the intermediate Ruby AST. - */ -class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] { - - private val logger = LoggerFactory.getLogger(getClass) - private val classNameGen = FreshNameGenerator(id => s"") - - protected def freshClassName(span: TextSpan): SimpleIdentifier = { - SimpleIdentifier(None)(span.spanStart(classNameGen.fresh)) - } - - private def defaultTextSpan(code: String = ""): TextSpan = TextSpan(None, None, None, None, code) - - override def defaultResult(): RubyNode = Unknown()(defaultTextSpan()) - - override protected def shouldVisitNextChild(node: RuleNode, currentResult: RubyNode): Boolean = - currentResult.isInstanceOf[Unknown] - - override def visit(tree: ParseTree): RubyNode = { - Option(tree).map(super.visit).getOrElse(defaultResult()) - } - - override def visitProgram(ctx: RubyParser.ProgramContext): RubyNode = { - visit(ctx.compoundStatement()) - } - - override def visitCompoundStatement(ctx: RubyParser.CompoundStatementContext): RubyNode = { - StatementList(ctx.getStatements.map(visit))(ctx.toTextSpan) - } - - override def visitGroupingStatement(ctx: RubyParser.GroupingStatementContext): RubyNode = { - // When there's only 1 statement, we can use it directly, instead of wrapping it in a StatementList. - val statements = ctx.compoundStatement().getStatements.map(visit) - if (statements.size == 1) { - statements.head - } else { - StatementList(statements)(ctx.toTextSpan) - } - } - - override def visitStatements(ctx: RubyParser.StatementsContext): RubyNode = { - StatementList(ctx.statement().asScala.map(visit).toList)(ctx.toTextSpan) - } - - override def visitWhileExpression(ctx: RubyParser.WhileExpressionContext): RubyNode = { - val condition = visit(ctx.expressionOrCommand()) - val body = visit(ctx.doClause()) - WhileExpression(condition, body)(ctx.toTextSpan) - } - - override def visitUntilExpression(ctx: RubyParser.UntilExpressionContext): RubyNode = { - val condition = visit(ctx.expressionOrCommand()) - val body = visit(ctx.doClause()) - UntilExpression(condition, body)(ctx.toTextSpan) - } - - override def visitBeginEndExpression(ctx: RubyParser.BeginEndExpressionContext): RubyNode = { - visit(ctx.bodyStatement()) - } - - override def visitIfExpression(ctx: RubyParser.IfExpressionContext): RubyNode = { - val condition = visit(ctx.expressionOrCommand()) - val thenBody = visit(ctx.thenClause()) - val elsifs = ctx.elsifClause().asScala.map(visit).toList - val elseBody = Option(ctx.elseClause()).map(visit) - IfExpression(condition, thenBody, elsifs, elseBody)(ctx.toTextSpan) - } - - override def visitElsifClause(ctx: RubyParser.ElsifClauseContext): RubyNode = { - ElsIfClause(visit(ctx.expressionOrCommand()), visit(ctx.thenClause()))(ctx.toTextSpan) - } - - override def visitElseClause(ctx: RubyParser.ElseClauseContext): RubyNode = { - ElseClause(visit(ctx.compoundStatement()))(ctx.toTextSpan) - } - - override def visitUnlessExpression(ctx: RubyParser.UnlessExpressionContext): RubyNode = { - val condition = visit(ctx.expressionOrCommand()) - val thenBody = visit(ctx.thenClause()) - val elseBody = Option(ctx.elseClause()).map(visit) - UnlessExpression(condition, thenBody, elseBody)(ctx.toTextSpan) - } - - override def visitForExpression(ctx: RubyParser.ForExpressionContext): RubyNode = { - val forVariable = visit(ctx.forVariable()) - val iterableVariable = visit(ctx.commandOrPrimaryValue()) - val doBlock = visit(ctx.doClause()) - ForExpression(forVariable, iterableVariable, doBlock)(ctx.toTextSpan) - } - - override def visitForVariable(ctx: RubyParser.ForVariableContext): RubyNode = { - if (ctx.leftHandSide() != null) visit(ctx.leftHandSide()) - else visit(ctx.multipleLeftHandSide()) - } - - override def visitModifierStatement(ctx: RubyParser.ModifierStatementContext): RubyNode = { - ctx.statementModifier().getText match - case "if" => - val condition = visit(ctx.expressionOrCommand()) - val thenBody = visit(ctx.statement()) - val elsifs = List() - val elseBody = None - IfExpression(condition, thenBody, elsifs, elseBody)(ctx.toTextSpan) - case "unless" => - val condition = visit(ctx.expressionOrCommand()) - val thenBody = visit(ctx.statement()) - val elseBody = None - UnlessExpression(condition, thenBody, elseBody)(ctx.toTextSpan) - case "while" => - val condition = visit(ctx.expressionOrCommand()) - val body = visit(ctx.statement()) - WhileExpression(condition, body)(ctx.toTextSpan) - case "until" => - val condition = visit(ctx.expressionOrCommand()) - val body = visit(ctx.statement()) - DoWhileExpression(condition, body)(ctx.toTextSpan) - case "rescue" => - val body = visit(ctx.statement()) - val thenClause = visit(ctx.expressionOrCommand()) - val rescueClause = - RescueClause(Option.empty, Option.empty, thenClause)(ctx.toTextSpan) - val rescExp = - RescueExpression(body, List(rescueClause), Option.empty, Option.empty)(ctx.toTextSpan) - rescExp - case _ => - logger.warn(s"Unhandled modifier statement ${ctx.getClass} ${ctx.toTextSpan} ") - Unknown()(ctx.toTextSpan) - } - - override def visitTernaryOperatorExpression(ctx: RubyParser.TernaryOperatorExpressionContext): RubyNode = { - val condition = visit(ctx.operatorExpression(0)) - val thenBody = visit(ctx.operatorExpression(1)) - val elseBody = visit(ctx.operatorExpression(2)) - IfExpression( - condition, - thenBody, - List.empty, - Option(ElseClause(StatementList(elseBody :: Nil)(elseBody.span))(elseBody.span)) - )(ctx.toTextSpan) - } - - override def visitReturnMethodInvocationWithoutParentheses( - ctx: RubyParser.ReturnMethodInvocationWithoutParenthesesContext - ): RubyNode = { - val expressions = ctx.primaryValueList().primaryValue().asScala.map(visit).toList - ReturnExpression(expressions)(ctx.toTextSpan) - } - - override def visitNumericLiteral(ctx: RubyParser.NumericLiteralContext): RubyNode = { - if (ctx.hasSign) { - UnaryExpression(ctx.sign.getText, visit(ctx.unsignedNumericLiteral()))(ctx.toTextSpan) - } else { - visit(ctx.unsignedNumericLiteral()) - } - } - - override def visitUnaryExpression(ctx: RubyParser.UnaryExpressionContext): RubyNode = { - UnaryExpression(ctx.unaryOperator().getText, visit(ctx.primaryValue()))(ctx.toTextSpan) - } - - override def visitUnaryMinusExpression(ctx: RubyParser.UnaryMinusExpressionContext): RubyNode = { - UnaryExpression(ctx.MINUS().getText, visit(ctx.primaryValue()))(ctx.toTextSpan) - } - - override def visitNotExpressionOrCommand(ctx: RubyParser.NotExpressionOrCommandContext): RubyNode = { - UnaryExpression(ctx.NOT().getText, visit(ctx.expressionOrCommand()))(ctx.toTextSpan) - } - - override def visitCommandExpressionOrCommand(ctx: RubyParser.CommandExpressionOrCommandContext): RubyNode = { - val methodInvocation = visit(ctx.methodInvocationWithoutParentheses()) - if (Option(ctx.EMARK()).isDefined) { - UnaryExpression(ctx.EMARK().getText, methodInvocation)(ctx.toTextSpan) - } else { - methodInvocation - } - } - - override def visitCommandWithDoBlock(ctx: CommandWithDoBlockContext): RubyNode = { - val name = Option(ctx.methodIdentifier()).orElse(Option(ctx.methodName())).map(visit).getOrElse(defaultResult()) - val arguments = ctx.arguments.map(visit) - val block = visit(ctx.doBlock()).asInstanceOf[Block] - SimpleCallWithBlock(name, arguments, block)(ctx.toTextSpan) - } - - override def visitHereDocs(ctx: RubyParser.HereDocsContext): RubyNode = { - HereDocNode(ctx.hereDoc().getText)(ctx.toTextSpan) - } - - override def visitPrimaryOperatorExpression(ctx: RubyParser.PrimaryOperatorExpressionContext): RubyNode = { - super.visitPrimaryOperatorExpression(ctx) match { - case x: BinaryExpression if x.lhs.text.endsWith("=") && x.op == "*" => - // fixme: This workaround handles a parser ambiguity with method identifiers having `=` and assignments with - // splatting on the RHS. The Ruby parser gives precedence to assignments over methods called with this suffix - // however - val newLhs = x.lhs match { - case call: SimpleCall => SimpleIdentifier(None)(call.span.spanStart(call.span.text.stripSuffix("="))) - case y => - logger.warn(s"Unhandled class in repacking of primary operator expression ${y.getClass}") - y - } - val newRhs = { - val oldRhsSpan = x.rhs.span - SplattingRubyNode(x.rhs)(oldRhsSpan.spanStart(s"*${oldRhsSpan.text}")) - } - SingleAssignment(newLhs, "=", newRhs)(x.span) - case x => x - } - } - - override def visitPowerExpression(ctx: RubyParser.PowerExpressionContext): RubyNode = { - BinaryExpression(visit(ctx.primaryValue(0)), ctx.powerOperator.getText, visit(ctx.primaryValue(1)))(ctx.toTextSpan) - } - - override def visitAdditiveExpression(ctx: RubyParser.AdditiveExpressionContext): RubyNode = { - BinaryExpression(visit(ctx.primaryValue(0)), ctx.additiveOperator().getText, visit(ctx.primaryValue(1)))( - ctx.toTextSpan - ) - } - - override def visitMultiplicativeExpression(ctx: RubyParser.MultiplicativeExpressionContext): RubyNode = { - BinaryExpression(visit(ctx.primaryValue(0)), ctx.multiplicativeOperator().getText, visit(ctx.primaryValue(1)))( - ctx.toTextSpan - ) - } - - override def visitLogicalAndExpression(ctx: RubyParser.LogicalAndExpressionContext): RubyNode = { - BinaryExpression(visit(ctx.primaryValue(0)), ctx.andOperator.getText, visit(ctx.primaryValue(1)))(ctx.toTextSpan) - } - - override def visitLogicalOrExpression(ctx: RubyParser.LogicalOrExpressionContext): RubyNode = { - BinaryExpression(visit(ctx.primaryValue(0)), ctx.orOperator.getText, visit(ctx.primaryValue(1)))(ctx.toTextSpan) - } - - override def visitKeywordAndOrExpressionOrCommand( - ctx: RubyParser.KeywordAndOrExpressionOrCommandContext - ): RubyNode = { - BinaryExpression(visit(ctx.lhs), ctx.binOp.getText, visit(ctx.rhs))(ctx.toTextSpan) - } - - override def visitShiftExpression(ctx: RubyParser.ShiftExpressionContext): RubyNode = { - BinaryExpression(visit(ctx.primaryValue(0)), ctx.bitwiseShiftOperator().getText, visit(ctx.primaryValue(1)))( - ctx.toTextSpan - ) - } - - override def visitBitwiseAndExpression(ctx: RubyParser.BitwiseAndExpressionContext): RubyNode = { - BinaryExpression(visit(ctx.primaryValue(0)), ctx.bitwiseAndOperator.getText, visit(ctx.primaryValue(1)))( - ctx.toTextSpan - ) - } - - override def visitBitwiseOrExpression(ctx: RubyParser.BitwiseOrExpressionContext): RubyNode = { - BinaryExpression(visit(ctx.primaryValue(0)), ctx.bitwiseOrOperator().getText, visit(ctx.primaryValue(1)))( - ctx.toTextSpan - ) - } - - override def visitRelationalExpression(ctx: RubyParser.RelationalExpressionContext): RubyNode = { - BinaryExpression(visit(ctx.primaryValue(0)), ctx.relationalOperator().getText, visit(ctx.primaryValue(1)))( - ctx.toTextSpan - ) - } - - override def visitEqualityExpression(ctx: RubyParser.EqualityExpressionContext): RubyNode = { - BinaryExpression(visit(ctx.primaryValue(0)), ctx.equalityOperator().getText, visit(ctx.primaryValue(1)))( - ctx.toTextSpan - ) - } - - override def visitDecimalUnsignedLiteral(ctx: RubyParser.DecimalUnsignedLiteralContext): RubyNode = { - StaticLiteral(getBuiltInType(Defines.Integer))(ctx.toTextSpan) - } - - override def visitBinaryUnsignedLiteral(ctx: RubyParser.BinaryUnsignedLiteralContext): RubyNode = { - StaticLiteral(getBuiltInType(Defines.Integer))(ctx.toTextSpan) - } - - override def visitOctalUnsignedLiteral(ctx: RubyParser.OctalUnsignedLiteralContext): RubyNode = { - StaticLiteral(getBuiltInType(Defines.Integer))(ctx.toTextSpan) - } - - override def visitHexadecimalUnsignedLiteral(ctx: RubyParser.HexadecimalUnsignedLiteralContext): RubyNode = { - StaticLiteral(getBuiltInType(Defines.Integer))(ctx.toTextSpan) - } - - override def visitFloatWithExponentUnsignedLiteral( - ctx: RubyParser.FloatWithExponentUnsignedLiteralContext - ): RubyNode = { - StaticLiteral(getBuiltInType(Defines.Float))(ctx.toTextSpan) - } - - override def visitFloatWithoutExponentUnsignedLiteral( - ctx: RubyParser.FloatWithoutExponentUnsignedLiteralContext - ): RubyNode = { - StaticLiteral(getBuiltInType(Defines.Float))(ctx.toTextSpan) - } - - override def visitPureSymbolLiteral(ctx: RubyParser.PureSymbolLiteralContext): RubyNode = { - StaticLiteral(getBuiltInType(Defines.Symbol))(ctx.toTextSpan) - } - - override def visitSingleQuotedSymbolLiteral(ctx: RubyParser.SingleQuotedSymbolLiteralContext): RubyNode = { - StaticLiteral(getBuiltInType(Defines.Symbol))(ctx.toTextSpan) - } - - override def visitNilPseudoVariable(ctx: RubyParser.NilPseudoVariableContext): RubyNode = { - StaticLiteral(getBuiltInType(Defines.NilClass))(ctx.toTextSpan) - } - - override def visitTruePseudoVariable(ctx: RubyParser.TruePseudoVariableContext): RubyNode = { - StaticLiteral(getBuiltInType(Defines.TrueClass))(ctx.toTextSpan) - } - - override def visitFalsePseudoVariable(ctx: RubyParser.FalsePseudoVariableContext): RubyNode = { - StaticLiteral(getBuiltInType(Defines.FalseClass))(ctx.toTextSpan) - } - - override def visitSingleQuotedStringExpression(ctx: RubyParser.SingleQuotedStringExpressionContext): RubyNode = { - if (!ctx.isInterpolated) { - StaticLiteral(getBuiltInType(Defines.String))(ctx.toTextSpan) - } else { - DynamicLiteral(getBuiltInType(Defines.String), ctx.interpolations.map(visit))(ctx.toTextSpan) - } - } - - override def visitQuotedNonExpandedStringLiteral(ctx: RubyParser.QuotedNonExpandedStringLiteralContext): RubyNode = { - StaticLiteral(getBuiltInType(Defines.String))(ctx.toTextSpan) - } - - override def visitDoubleQuotedStringExpression(ctx: RubyParser.DoubleQuotedStringExpressionContext): RubyNode = { - if (!ctx.isInterpolated) { - StaticLiteral(getBuiltInType(Defines.String))(ctx.toTextSpan) - } else { - DynamicLiteral(getBuiltInType(Defines.String), ctx.interpolations.map(visit))(ctx.toTextSpan) - } - } - - override def visitDoubleQuotedSymbolLiteral(ctx: RubyParser.DoubleQuotedSymbolLiteralContext): RubyNode = { - if (!ctx.isInterpolated) { - StaticLiteral(getBuiltInType(Defines.Symbol))(ctx.toTextSpan) - } else { - DynamicLiteral(getBuiltInType(Defines.Symbol), ctx.interpolations.map(visit))(ctx.toTextSpan) - } - } - - override def visitQuotedExpandedStringLiteral(ctx: RubyParser.QuotedExpandedStringLiteralContext): RubyNode = { - if (!ctx.isInterpolated) { - StaticLiteral(getBuiltInType(Defines.String))(ctx.toTextSpan) - } else { - DynamicLiteral(getBuiltInType(Defines.String), ctx.interpolations.map(visit))(ctx.toTextSpan) - } - } - - override def visitRegularExpressionLiteral(ctx: RubyParser.RegularExpressionLiteralContext): RubyNode = { - if (ctx.isStatic) { - StaticLiteral(getBuiltInType(Defines.Regexp))(ctx.toTextSpan) - } else { - DynamicLiteral(getBuiltInType(Defines.Regexp), ctx.interpolations.map(visit))(ctx.toTextSpan) - } - } - - override def visitQuotedExpandedRegularExpressionLiteral( - ctx: RubyParser.QuotedExpandedRegularExpressionLiteralContext - ): RubyNode = { - if (ctx.isStatic) { - StaticLiteral(getBuiltInType(Defines.Regexp))(ctx.toTextSpan) - } else { - DynamicLiteral(getBuiltInType(Defines.Regexp), ctx.interpolations.map(visit))(ctx.toTextSpan) - } - } - - override def visitCurlyBracesBlock(ctx: RubyParser.CurlyBracesBlockContext): RubyNode = { - val parameters = Option(ctx.blockParameter()).fold(List())(_.parameters).map(visit) - val body = visit(ctx.compoundStatement()) - Block(parameters, body)(ctx.toTextSpan) - } - - override def visitDoBlock(ctx: RubyParser.DoBlockContext): RubyNode = { - val parameters = Option(ctx.blockParameter()).fold(List())(_.parameters).map(visit) - val body = visit(ctx.bodyStatement()) - Block(parameters, body)(ctx.toTextSpan) - } - override def visitLocalVariableAssignmentExpression( - ctx: RubyParser.LocalVariableAssignmentExpressionContext - ): RubyNode = { - val lhs = visit(ctx.lhs) - val rhs = visit(ctx.rhs) - val op = ctx.assignmentOperator().getText - SingleAssignment(lhs, op, rhs)(ctx.toTextSpan) - } - - private def flattenStatementLists(x: List[RubyNode]): List[RubyNode] = { - x match { - case (head: StatementList) :: xs => head.statements ++ flattenStatementLists(xs) - case head :: tail => head +: flattenStatementLists(tail) - case Nil => Nil - } - } - - override def visitMultipleAssignmentStatement(ctx: RubyParser.MultipleAssignmentStatementContext): RubyNode = { - - /** Recursively expand and duplicate splatting nodes so that they line up with what they consume. - * - * @param nodes - * the splat nodes. - * @param expandSize - * how many more duplicates to create. - */ - def slurp(nodes: List[RubyNode], expandSize: Int): List[RubyNode] = nodes match { - case (head: SplattingRubyNode) :: tail if expandSize > 0 => head :: slurp(head :: tail, expandSize - 1) - case head :: tail => head :: slurp(tail, expandSize) - case Nil => List.empty - } - - val lhsNodes = Option(ctx.multipleLeftHandSide()) - .map(visit) - .orElse( - Option(ctx.leftHandSide()) - .map(visit) - .map(node => SplattingRubyNode(node)(node.span.spanStart(s"*${node.span.text}"))) - ) - .getOrElse(defaultResult()) match { - case x: StatementList => flattenStatementLists(x.statements) - case x => List(x) - } - val rhsNodes = Option(ctx.multipleRightHandSide()) - .map(visit) - .getOrElse(defaultResult()) match { - case x: StatementList => flattenStatementLists(x.statements) - case x => List(x) - } - val op = ctx.EQ().toString - - lazy val defaultAssignments = lhsNodes - .zipAll(rhsNodes, defaultResult(), Unknown()(defaultTextSpan(Defines.Undefined))) - .map { case (lhs, rhs) => SingleAssignment(lhs, op, rhs)(ctx.toTextSpan) } - - val assignments = if ((lhsNodes ++ rhsNodes).exists(_.isInstanceOf[SplattingRubyNode])) { - rhsNodes.size - lhsNodes.size match { - // Handle slurping the RHS values - case x if x > 0 => { - val slurpedLhs = slurp(lhsNodes, x) - - slurpedLhs - .zip(rhsNodes) - .groupBy(_._1) - .toSeq - .map { case (lhsNode, xs) => lhsNode -> xs.map(_._2) } - .sortBy { x => slurpedLhs.indexOf(x._1) } // groupBy produces a map which discards insertion order - .map { - case (SplattingRubyNode(lhs), rhss) => - SingleAssignment(lhs, op, ArrayLiteral(rhss)(ctx.toTextSpan))(ctx.toTextSpan) - case (lhs, rhs :: Nil) => SingleAssignment(lhs, op, rhs)(ctx.toTextSpan) - case (lhs, rhss) => SingleAssignment(lhs, op, ArrayLiteral(rhss)(ctx.toTextSpan))(ctx.toTextSpan) - } - .toList - } - // Handle splitting the RHS values - case x if x < 0 => { - val slurpedRhs = slurp(rhsNodes, Math.abs(x)) - - lhsNodes - .zip(slurpedRhs) - .groupBy(_._2) - .toSeq - .map { case (rhsNode, xs) => rhsNode -> xs.map(_._1) } - .sortBy { x => slurpedRhs.indexOf(x._1) } // groupBy produces a map which discards insertion order - .flatMap { - case (SplattingRubyNode(rhs), lhss) => - lhss.map(SingleAssignment(_, op, SplattingRubyNode(rhs)(rhs.span))(ctx.toTextSpan)) - case (rhs, lhs :: Nil) => Seq(SingleAssignment(lhs, op, rhs)(ctx.toTextSpan)) - case (rhs, lhss) => lhss.map(SingleAssignment(_, op, SplattingRubyNode(rhs)(rhs.span))(ctx.toTextSpan)) - } - .toList - } - case _ => defaultAssignments - } - } else { - defaultAssignments - } - MultipleAssignment(assignments)(ctx.toTextSpan) - } - - override def visitMultipleLeftHandSide(ctx: RubyParser.MultipleLeftHandSideContext): RubyNode = { - val multiLhsItems = ctx.multipleLeftHandSideItem.asScala.map(visit).toList - val packingLHSNodes = Option(ctx.packingLeftHandSide) - .map(visit) - .map { - case StatementList(statements) => statements - case x => List(x) - } - .getOrElse(List.empty) - val procParameter = Option(ctx.procParameter).map(visit).toList - val groupedLhs = Option(ctx.groupedLeftHandSide).map(visit).toList - val statements = multiLhsItems ++ packingLHSNodes ++ procParameter ++ groupedLhs - StatementList(statements)(ctx.toTextSpan) - } - - override def visitPackingLeftHandSide(ctx: RubyParser.PackingLeftHandSideContext): RubyNode = { - val splatNode = SplattingRubyNode(visit(ctx.leftHandSide))(ctx.toTextSpan) - Option(ctx.multipleLeftHandSideItem()).map(_.asScala.map(visit).toList).getOrElse(List.empty) match { - case Nil => splatNode - case xs => StatementList(splatNode +: xs)(ctx.toTextSpan) - } - } - - override def visitMultipleRightHandSide(ctx: RubyParser.MultipleRightHandSideContext): RubyNode = { - val rhsSplatting = Option(ctx.splattingRightHandSide()).map(_.splattingArgument()).map(visit).toList - Option(ctx.operatorExpressionList()) - .map(x => StatementList(x.operatorExpression().asScala.map(visit).toList ++ rhsSplatting)(ctx.toTextSpan)) - .getOrElse(defaultResult()) - } - - override def visitSplattingArgument(ctx: RubyParser.SplattingArgumentContext): RubyNode = { - SplattingRubyNode(visit(ctx.operatorExpression()))(ctx.toTextSpan) - } - - override def visitAttributeAssignmentExpression(ctx: RubyParser.AttributeAssignmentExpressionContext): RubyNode = { - val lhs = visit(ctx.primaryValue()) - val op = ctx.op.getText - val memberName = ctx.methodName.getText - val rhs = visit(ctx.operatorExpression()) - AttributeAssignment(lhs, op, memberName, rhs)(ctx.toTextSpan) - } - - override def visitSimpleCommand(ctx: RubyParser.SimpleCommandContext): RubyNode = { - if (Option(ctx.commandArgument()).map(_.getText).exists(_.startsWith("::"))) { - val memberName = ctx.commandArgument().getText.stripPrefix("::") - if (memberName.headOption.exists(_.isUpper)) { // Constant accesses are upper-case 1st letter - MemberAccess(visit(ctx.methodIdentifier()), "::", memberName)(ctx.toTextSpan) - } else { - MemberCall(visit(ctx.methodIdentifier()), "::", memberName, Nil)(ctx.toTextSpan) - } - } else if (!ctx.methodIdentifier().isAttrDeclaration) { - val identifierCtx = ctx.methodIdentifier() - val arguments = ctx.commandArgument().arguments.map(visit) - (identifierCtx.getText, arguments) match { - case ("require", List(argument)) => - RequireCall(visit(identifierCtx), argument)(ctx.toTextSpan) - case ("require_relative", List(argument)) => - RequireCall(visit(identifierCtx), argument, true)(ctx.toTextSpan) - case ("require_all", List(argument)) => - RequireCall(visit(identifierCtx), argument, true, true)(ctx.toTextSpan) - case ("include", List(argument)) => - IncludeCall(visit(identifierCtx), argument)(ctx.toTextSpan) - case (idAssign, arguments) if idAssign.endsWith("=") => - // fixme: This workaround handles a parser ambiguity with method identifiers having `=` and assignments. - // The Ruby parser gives precedence to assignments over methods called with this suffix however - val lhsIdentifier = SimpleIdentifier(None)(identifierCtx.toTextSpan.spanStart(idAssign.stripSuffix("="))) - val argNode = arguments match { - case arg :: Nil => arg - case xs => ArrayLiteral(xs)(ctx.commandArgument().toTextSpan) - } - SingleAssignment(lhsIdentifier, "=", argNode)(ctx.toTextSpan) - case _ => - SimpleCall(visit(identifierCtx), arguments)(ctx.toTextSpan) - } - } else { - FieldsDeclaration(ctx.commandArgument().arguments.map(visit))(ctx.toTextSpan) - } - } - - override def visitIsDefinedExpression(ctx: RubyParser.IsDefinedExpressionContext): RubyNode = { - SimpleCall(visit(ctx.isDefinedKeyword), visit(ctx.expressionOrCommand()) :: Nil)(ctx.toTextSpan) - } - - override def visitIsDefinedCommand(ctx: RubyParser.IsDefinedCommandContext): RubyNode = { - SimpleCall(visit(ctx.isDefinedKeyword), visit(ctx.primaryValue()) :: Nil)(ctx.toTextSpan) - } - - override def visitMethodCallExpression(ctx: RubyParser.MethodCallExpressionContext): RubyNode = { - SimpleCall(visit(ctx.methodOnlyIdentifier()), List())(ctx.toTextSpan) - } - - override def visitMethodCallWithBlockExpression(ctx: RubyParser.MethodCallWithBlockExpressionContext): RubyNode = { - ctx.methodIdentifier().getText match { - case Defines.Proc | Defines.Lambda => ProcOrLambdaExpr(visit(ctx.block()).asInstanceOf[Block])(ctx.toTextSpan) - case Defines.Loop => - DoWhileExpression( - SimpleIdentifier(Option(Defines.getBuiltInType(Defines.TrueClass)))( - ctx.methodIdentifier().toTextSpan.spanStart("true") - ), - ctx.block() match { - case b: RubyParser.DoBlockBlockContext => - visit(b.doBlock().bodyStatement()) - case y => - logger.warn(s"Unexpected loop block body ${y.getClass}") - visit(ctx.block()) - } - )(ctx.toTextSpan) - case _ => - SimpleCallWithBlock(visit(ctx.methodIdentifier()), List(), visit(ctx.block()).asInstanceOf[Block])( - ctx.toTextSpan - ) - } - } - - override def visitLambdaExpression(ctx: RubyParser.LambdaExpressionContext): RubyNode = { - val parameters = Option(ctx.parameterList()).fold(List())(_.parameters).map(visit) - val body = visit(ctx.block()) - ProcOrLambdaExpr(Block(parameters, body)(ctx.toTextSpan))(ctx.toTextSpan) - } - - override def visitMethodCallWithParenthesesExpression( - ctx: RubyParser.MethodCallWithParenthesesExpressionContext - ): RubyNode = { - if (Option(ctx.block()).isDefined) { - SimpleCallWithBlock( - visit(ctx.methodIdentifier()), - ctx.argumentWithParentheses().arguments.map(visit), - visit(ctx.block()).asInstanceOf[Block] - )(ctx.toTextSpan) - } else { - SimpleCall(visit(ctx.methodIdentifier()), ctx.argumentWithParentheses().arguments.map(visit))(ctx.toTextSpan) - } - } - - override def visitYieldExpression(ctx: RubyParser.YieldExpressionContext): RubyNode = { - val arguments = Option(ctx.argumentWithParentheses()).iterator.flatMap(_.arguments).map(visit).toList - YieldExpr(arguments)(ctx.toTextSpan) - } - - override def visitYieldMethodInvocationWithoutParentheses( - ctx: RubyParser.YieldMethodInvocationWithoutParenthesesContext - ): RubyNode = { - val arguments = ctx.primaryValueList().primaryValue().asScala.map(visit).toList - YieldExpr(arguments)(ctx.toTextSpan) - } - - override def visitMemberAccessCommand(ctx: RubyParser.MemberAccessCommandContext): RubyNode = { - val arg = visit(ctx.commandArgument()) - val methodName = visit(ctx.methodName()) - val base = visit(ctx.primary()) - MemberCall(base, ".", methodName.text, List(arg))(ctx.toTextSpan) - } - - override def visitConstantIdentifierVariable(ctx: RubyParser.ConstantIdentifierVariableContext): RubyNode = { - SimpleIdentifier()(ctx.toTextSpan) - } - - override def visitGlobalIdentifierVariable(ctx: RubyParser.GlobalIdentifierVariableContext): RubyNode = { - SimpleIdentifier()(ctx.toTextSpan) - } - - override def visitClassIdentifierVariable(ctx: RubyParser.ClassIdentifierVariableContext): RubyNode = { - ClassFieldIdentifier()(ctx.toTextSpan) - } - - override def visitInstanceIdentifierVariable(ctx: RubyParser.InstanceIdentifierVariableContext): RubyNode = { - InstanceFieldIdentifier()(ctx.toTextSpan) - } - - override def visitLocalIdentifierVariable(ctx: RubyParser.LocalIdentifierVariableContext): RubyNode = { - SimpleIdentifier()(ctx.toTextSpan) - } - - override def visitClassName(ctx: RubyParser.ClassNameContext): RubyNode = { - SimpleIdentifier()(ctx.toTextSpan) - } - - override def visitMethodIdentifier(ctx: RubyParser.MethodIdentifierContext): RubyNode = { - SimpleIdentifier()(ctx.toTextSpan) - } - - override def visitMethodOnlyIdentifier(ctx: RubyParser.MethodOnlyIdentifierContext): RubyNode = { - SimpleIdentifier()(ctx.toTextSpan) - } - - override def visitIsDefinedKeyword(ctx: RubyParser.IsDefinedKeywordContext): RubyNode = { - SimpleIdentifier()(ctx.toTextSpan) - } - - override def visitLinePseudoVariable(ctx: RubyParser.LinePseudoVariableContext): RubyNode = { - SimpleIdentifier(Some(getBuiltInType(Defines.Integer)))(ctx.toTextSpan) - } - - override def visitFilePseudoVariable(ctx: RubyParser.FilePseudoVariableContext): RubyNode = { - SimpleIdentifier(Some(getBuiltInType(Defines.String)))(ctx.toTextSpan) - } - - override def visitEncodingPseudoVariable(ctx: RubyParser.EncodingPseudoVariableContext): RubyNode = { - SimpleIdentifier(Some(getBuiltInType(Defines.Encoding)))(ctx.toTextSpan) - } - - override def visitSelfPseudoVariable(ctx: RubyParser.SelfPseudoVariableContext): RubyNode = { - SelfIdentifier()(ctx.toTextSpan) - } - - override def visitMemberAccessExpression(ctx: RubyParser.MemberAccessExpressionContext): RubyNode = { - val hasArguments = Option(ctx.argumentWithParentheses()).isDefined - val hasBlock = Option(ctx.block()).isDefined - val isClassDecl = Option(ctx.primaryValue()).map(_.getText).contains("Class") && Option(ctx.methodName()) - .map(_.getText) - .contains("new") - val methodName = ctx.methodName().getText - - if (!hasBlock) { - val target = visit(ctx.primaryValue()) - if (methodName == "new") { - if (!hasArguments) { - return SimpleObjectInstantiation(target, List.empty)(ctx.toTextSpan) - } else { - return SimpleObjectInstantiation(target, ctx.argumentWithParentheses().arguments.map(visit))(ctx.toTextSpan) - } - } else { - if (!hasArguments) { - if (methodName.headOption.exists(_.isUpper)) { - return MemberAccess(target, ctx.op.getText, methodName)(ctx.toTextSpan) - } else { - return MemberCall(target, ctx.op.getText, methodName, Nil)(ctx.toTextSpan) - } - } else { - return MemberCall(target, ctx.op.getText, methodName, ctx.argumentWithParentheses().arguments.map(visit))( - ctx.toTextSpan - ) - } - } - } - - if (hasBlock && isClassDecl) { - val block = visit(ctx.block()).asInstanceOf[Block] - return AnonymousClassDeclaration(freshClassName(ctx.primaryValue().toTextSpan), None, block.body)(ctx.toTextSpan) - } else if (hasBlock) { - val block = visit(ctx.block()).asInstanceOf[Block] - val target = visit(ctx.primaryValue()) - if (methodName == "new") { - if (!hasArguments) { - return ObjectInstantiationWithBlock(target, List.empty, block)(ctx.toTextSpan) - } else { - return ObjectInstantiationWithBlock(target, ctx.argumentWithParentheses().arguments.map(visit), block)( - ctx.toTextSpan - ) - } - } else { - return MemberCallWithBlock( - target, - ctx.op.getText, - methodName, - Option(ctx.argumentWithParentheses()).map(_.arguments).getOrElse(List()).map(visit), - visit(ctx.block()).asInstanceOf[Block] - )(ctx.toTextSpan) - } - } - - logger.warn(s"MemberAccessExpression not handled: '${ctx.toTextSpan}'") - Unknown()(ctx.toTextSpan) - } - - override def visitConstantVariableReference(ctx: ConstantVariableReferenceContext): RubyNode = { - MemberAccess(SelfIdentifier()(ctx.toTextSpan.spanStart(Defines.Self)), "::", ctx.CONSTANT_IDENTIFIER().getText)( - ctx.toTextSpan - ) - } - - override def visitIndexingAccessExpression(ctx: RubyParser.IndexingAccessExpressionContext): RubyNode = { - IndexAccess( - visit(ctx.primaryValue()), - Option(ctx.indexingArgumentList()).map(_.arguments).getOrElse(List()).map(visit) - )(ctx.toTextSpan) - } - - override def visitBracketedArrayLiteral(ctx: RubyParser.BracketedArrayLiteralContext): RubyNode = { - ArrayLiteral(Option(ctx.indexingArgumentList()).map(_.arguments).getOrElse(List()).map(visit))(ctx.toTextSpan) - } - - override def visitQuotedNonExpandedStringArrayLiteral( - ctx: RubyParser.QuotedNonExpandedStringArrayLiteralContext - ): RubyNode = { - val elements = Option(ctx.quotedNonExpandedArrayElementList()) - .map(_.elements) - .getOrElse(List()) - .map(elemCtx => StaticLiteral(getBuiltInType(Defines.String))(elemCtx.toTextSpan)) - ArrayLiteral(elements)(ctx.toTextSpan) - } - - override def visitQuotedNonExpandedSymbolArrayLiteral( - ctx: RubyParser.QuotedNonExpandedSymbolArrayLiteralContext - ): RubyNode = { - val elements = Option(ctx.quotedNonExpandedArrayElementList()) - .map(_.elements) - .getOrElse(List()) - .map(elemCtx => StaticLiteral(getBuiltInType(Defines.Symbol))(elemCtx.toTextSpan)) - ArrayLiteral(elements)(ctx.toTextSpan) - } - - override def visitRangeExpression(ctx: RubyParser.RangeExpressionContext): RubyNode = { - RangeExpression( - visit(ctx.primaryValue(0)), - visit(ctx.primaryValue(1)), - visit(ctx.rangeOperator()).asInstanceOf[RangeOperator] - )(ctx.toTextSpan) - } - - override def visitRangeOperator(ctx: RubyParser.RangeOperatorContext): RubyNode = { - RangeOperator(Option(ctx.DOT2()).isEmpty)(ctx.toTextSpan) - } - - override def visitHashLiteral(ctx: RubyParser.HashLiteralContext): RubyNode = { - HashLiteral(Option(ctx.associationList()).map(_.associations).getOrElse(List()).map(visit))(ctx.toTextSpan) - } - - override def visitAssociation(ctx: RubyParser.AssociationContext): RubyNode = { - ctx.associationKey().getText match { - case "if" => - Association(SimpleIdentifier()(ctx.toTextSpan.spanStart("if")), visit(ctx.operatorExpression()))(ctx.toTextSpan) - case _ => - Association(visit(ctx.associationKey()), visit(ctx.operatorExpression()))(ctx.toTextSpan) - } - } - - override def visitModuleDefinition(ctx: RubyParser.ModuleDefinitionContext): RubyNode = { - val (nonFieldStmts, fields) = genInitFieldStmts(ctx.bodyStatement()) - - val moduleName = visit(ctx.classPath()) - val memberCall = createBodyMemberCall(moduleName.span.text, ctx.toTextSpan) - - ModuleDeclaration(visit(ctx.classPath()), nonFieldStmts, fields, Option(memberCall))(ctx.toTextSpan) - } - - override def visitSingletonClassDefinition(ctx: RubyParser.SingletonClassDefinitionContext): RubyNode = { - SingletonClassDeclaration( - freshClassName(ctx.toTextSpan), - Option(ctx.commandOrPrimaryValueClass()).map(visit), - visit(ctx.bodyStatement()) - )(ctx.toTextSpan) - } - - private def findFieldsInMethodDecls(methodDecls: List[MethodDeclaration]): List[RubyNode & RubyFieldIdentifier] = { - // TODO: Handle case where body of method is not a StatementList - methodDecls - .flatMap { x => - x.body match { - case stmtList: StatementList => - stmtList.statements.collect { case x: SingleAssignment => - x.lhs - } - case _ => List.empty - } - } - .collect { case x: (RubyNode & RubyFieldIdentifier) => - x - } - } - - private def genInitFieldStmts( - ctxBodyStatement: RubyParser.BodyStatementContext - ): (RubyNode, List[RubyNode & RubyFieldIdentifier]) = { - val loweredClassDecls = lowerSingletonClassDeclarations(ctxBodyStatement) - - /** Generates SingleAssignment RubyNodes for list of fields and fields found in method decls - */ - def genSingleAssignmentStmtList( - fields: List[RubyNode], - fieldsInMethodDecls: List[RubyNode] - ): List[SingleAssignment] = { - (fields ++ fieldsInMethodDecls).map { x => - SingleAssignment(x, "=", StaticLiteral(getBuiltInType(Defines.NilClass))(x.span.spanStart("nil")))( - x.span.spanStart(s"${x.span.text} = nil") - ) - } - } - - /** Partition RubyFields into InstanceFieldIdentifiers and ClassFieldIdentifiers - */ - def partitionRubyFields(fields: List[RubyNode]): (List[RubyNode], List[RubyNode]) = { - fields.partition { - case _: InstanceFieldIdentifier => true - case _ => false - } - } - - loweredClassDecls match { - case stmtList: StatementList => - val (rubyFieldIdentifiers, otherStructures) = stmtList.statements.partition { - case x: (RubyNode & RubyFieldIdentifier) => true - case _ => false - } - val (fieldAssignments, rest) = otherStructures - .map { - case x @ SingleAssignment(lhs: SimpleIdentifier, op, rhs) => - SingleAssignment(ClassFieldIdentifier()(lhs.span), op, rhs)(x.span) - case x @ SingleAssignment(lhs: RubyFieldIdentifier, op, rhs) => - // Perhaps non-intuitive, but @ fields assigned under a type belong to the singleton class - SingleAssignment(ClassFieldIdentifier()(lhs.span), op, rhs)(x.span) - case x => x - } - .partition { - case x: SingleAssignment => true - case _ => false - } - - val (instanceFields, classFields) = partitionRubyFields(rubyFieldIdentifiers) - - val methodDecls = rest.collect { case x: MethodDeclaration => - x - } - - val fieldsInMethodDecls = findFieldsInMethodDecls(methodDecls) - - val (instanceFieldsInMethodDecls, classFieldsInMethodDecls) = partitionRubyFields(fieldsInMethodDecls) - - val initializeMethod = methodDecls.collectFirst { case x if x.methodName == Defines.Initialize => x } - - val initStmtListStatements = genSingleAssignmentStmtList(instanceFields, instanceFieldsInMethodDecls) - val clinitStmtList = genSingleAssignmentStmtList(classFields, classFieldsInMethodDecls) ++ fieldAssignments - - val bodyMethodStmtList = - StatementList(initStmtListStatements ++ clinitStmtList)( - stmtList.span - .spanStart(initStmtListStatements.map(_.span.text).concat(clinitStmtList.map(_.span.text)).mkString("\n")) - ) - - val bodyMethod = MethodDeclaration(Defines.TypeDeclBody, List.empty, bodyMethodStmtList)( - stmtList.span.spanStart(s"def \n${bodyMethodStmtList.span.text}\nend") - ) - - val combinedFields = rubyFieldIdentifiers ++ fieldsInMethodDecls ++ - fieldAssignments.collect { case SingleAssignment(lhs: RubyFieldIdentifier, _, _) => lhs } - - ( - StatementList(bodyMethod +: rest)(bodyMethod.span), - combinedFields.asInstanceOf[List[RubyNode & RubyFieldIdentifier]] - ) - case decls => (decls, List.empty) - } - } - - /** Detects the alias statements and creates methods that reference the aliased method as a call. - * @param classBody - * the class body node - * @return - * the class body as a statement list. - */ - private def lowerAliasStatementsToMethods(classBody: RubyNode): StatementList = { - - val classBodyStmts = classBody match { - case StatementList(stmts) => stmts - case x => List(x) - } - - val methodParamMap = classBodyStmts.collect { case method: MethodDeclaration => - method.methodName -> method.parameters - }.toMap - - val loweredMethods = classBodyStmts.collect { case alias: AliasStatement => - methodParamMap.get(alias.oldName) match { - case Some(aliasingMethodParams) => - val argsCode = aliasingMethodParams.map(_.text).mkString(", ") - val callCode = s"${alias.oldName}($argsCode)" - MethodDeclaration( - alias.newName, - aliasingMethodParams, - StatementList( - SimpleCall( - SimpleIdentifier(None)(alias.span.spanStart(alias.oldName)), - aliasingMethodParams.map { x => SimpleIdentifier(None)(alias.span.spanStart(x.span.text)) } - )(alias.span.spanStart(callCode)) :: Nil - )(alias.span.spanStart(callCode)) - )(alias.span.spanStart(s"def ${alias.newName}($argsCode)")) - case None => - logger.warn( - s"Unable to correctly lower aliased method ${alias.oldName}, the result will be in degraded parameter/argument flows" - ) - MethodDeclaration( - alias.newName, - List.empty, - StatementList( - SimpleCall(SimpleIdentifier(None)(alias.span.spanStart(alias.oldName)), List.empty)(alias.span) :: Nil - )(alias.span) - )(alias.span) - } - } - - StatementList(classBodyStmts.filterNot(_.isInstanceOf[AliasStatement]) ++ loweredMethods)(classBody.span) - } - - /** Moves children nodes not allowed directly under TypeDecl to the `initialize` method - * @param stmts - * \- StatementList for ClassDecl - * @return - * - `initialize` MethodDeclaration with all non-allowed children nodes added - * - list of all nodes allowed directly under type decl - */ - private def filterNonAllowedTypeDeclChildren(stmts: StatementList): RubyNode = { - val (initMethod, nonInitStmts) = stmts.statements.partition { - case x: MethodDeclaration if x.methodName == Defines.Initialize => true - case _ => false - } - - val (allowedTypeDeclChildren, nonAllowedTypeDeclChildren) = nonInitStmts.partition { - case x: AllowedTypeDeclarationChild => true - case _ => false - } - - val (bodyMethod, otherTypeDeclChildren) = allowedTypeDeclChildren.partition { - case x: MethodDeclaration if x.methodName == Defines.TypeDeclBody => true - case _ => false - } - - val updatedBodyMethod = bodyMethod - .asInstanceOf[List[MethodDeclaration]] - .map { x => - val methodDeclStmts = - StatementList(x.body.asInstanceOf[StatementList].statements ++ nonAllowedTypeDeclChildren)( - x.span.spanStart(s"${x.body.span.text}${nonAllowedTypeDeclChildren.map(_.span.text).mkString("\n")}") - ) - - MethodDeclaration(x.methodName, x.parameters, methodDeclStmts)( - x.span.spanStart(s"def \n${methodDeclStmts.span.text}\nend") - ) - } - - StatementList(otherTypeDeclChildren ++ updatedBodyMethod)(stmts.span) - } - - override def visitClassDefinition(ctx: RubyParser.ClassDefinitionContext): RubyNode = { - val (nonFieldStmts, fields) = genInitFieldStmts(ctx.bodyStatement()) - - val stmts = lowerAliasStatementsToMethods(nonFieldStmts) - - val classBody = filterNonAllowedTypeDeclChildren(stmts) - val className = visit(ctx.classPath()) - - val memberCall = createBodyMemberCall(className.span.text, ctx.toTextSpan) - - ClassDeclaration( - visit(ctx.classPath()), - Option(ctx.commandOrPrimaryValueClass()).map(visit), - classBody, - fields, - Option(memberCall) - )(ctx.toTextSpan) - } - - private def createBodyMemberCall(name: String, textSpan: TextSpan): MemberCall = { - MemberCall( - MemberAccess(SelfIdentifier()(textSpan.spanStart(Defines.Self)), "::", name)( - textSpan.spanStart(s"${Defines.Self}::$name") - ), - "::", - Defines.TypeDeclBody, - List.empty - )(textSpan.spanStart(s"${Defines.Self}::$name::")) - } - - /** Lowers all MethodDeclaration found in SingletonClassDeclaration to SingletonMethodDeclaration. - * @param ctx - * body context from class definitions - * @return - * RubyNode with lowered MethodDeclarations where required - */ - private def lowerSingletonClassDeclarations(ctx: RubyParser.BodyStatementContext): RubyNode = { - visit(ctx) match { - case stmtList: StatementList => - StatementList(stmtList.statements.flatMap { - case singletonClassDeclaration: SingletonClassDeclaration => - singletonClassDeclaration.baseClass match { - case Some(selfIdentifier: SelfIdentifier) => - singletonClassDeclaration.body match { - case singletonClassStmtList: StatementList => - singletonClassStmtList.statements.map { - case method: MethodDeclaration => - SingletonMethodDeclaration(selfIdentifier, method.methodName, method.parameters, method.body)( - method.span - ) - case nonMethodStatement => nonMethodStatement - } - case singletonBody => singletonBody :: Nil - } - case _ => singletonClassDeclaration.body :: Nil - } - case nonStmtListBody => nonStmtListBody :: Nil - })(stmtList.span) - case nonStmtList => nonStmtList - } - } - - override def visitMethodDefinition(ctx: RubyParser.MethodDefinitionContext): RubyNode = { - MethodDeclaration( - ctx.definedMethodName().getText, - Option(ctx.methodParameterPart().parameterList()).fold(List())(_.parameters).map(visit), - visit(ctx.bodyStatement()) - )(ctx.toTextSpan) - } - - override def visitEndlessMethodDefinition(ctx: RubyParser.EndlessMethodDefinitionContext): RubyNode = { - val body = visit(ctx.statement()) match { - case x: StatementList => x - case x => StatementList(x :: Nil)(x.span) - } - MethodDeclaration( - ctx.definedMethodName().getText, - Option(ctx.parameterList()).fold(List())(_.parameters).map(visit), - body - )(ctx.toTextSpan) - } - - override def visitSingletonMethodDefinition(ctx: RubyParser.SingletonMethodDefinitionContext): RubyNode = { - SingletonMethodDeclaration( - visit(ctx.singletonObject()), - ctx.definedMethodName().getText, - Option(ctx.methodParameterPart().parameterList()).fold(List())(_.parameters).map(visit), - visit(ctx.bodyStatement()) - )(ctx.toTextSpan) - } - - override def visitProcParameter(ctx: RubyParser.ProcParameterContext): RubyNode = { - ProcParameter( - Option(ctx.procParameterName).map(_.LOCAL_VARIABLE_IDENTIFIER()).map(_.getText()).getOrElse(ctx.getText()) - )(ctx.toTextSpan) - } - - override def visitHashParameter(ctx: RubyParser.HashParameterContext): RubyNode = { - HashParameter(Option(ctx.LOCAL_VARIABLE_IDENTIFIER()).map(_.getText).getOrElse(ctx.getText))(ctx.toTextSpan) - } - - override def visitArrayParameter(ctx: RubyParser.ArrayParameterContext): RubyNode = { - ArrayParameter(Option(ctx.LOCAL_VARIABLE_IDENTIFIER()).map(_.getText).getOrElse(ctx.getText))(ctx.toTextSpan) - } - - override def visitOptionalParameter(ctx: RubyParser.OptionalParameterContext): RubyNode = { - OptionalParameter( - ctx.optionalParameterName().LOCAL_VARIABLE_IDENTIFIER().toString, - visit(ctx.operatorExpression()) - )(ctx.toTextSpan) - } - - override def visitMandatoryParameter(ctx: RubyParser.MandatoryParameterContext): RubyNode = { - MandatoryParameter(ctx.LOCAL_VARIABLE_IDENTIFIER().toString)(ctx.toTextSpan) - } - - override def visitVariableLeftHandSide(ctx: RubyParser.VariableLeftHandSideContext): RubyNode = { - if (Option(ctx.primary()).isEmpty) { - MandatoryParameter(ctx.toTextSpan.text)(ctx.toTextSpan) - } else { - logger.warn(s"Variable LHS without primary expression is not handled: '${ctx.toTextSpan}'") - Unknown()(ctx.toTextSpan) - } - } - - override def visitBodyStatement(ctx: RubyParser.BodyStatementContext): RubyNode = { - val body = visit(ctx.compoundStatement()) - val rescueClauses = - Option(ctx.rescueClause.asScala).fold(List())(_.map(visit).toList).collect { case x: RescueClause => x } - val elseClause = Option(ctx.elseClause).map(visit).collect { case x: ElseClause => x } - val ensureClause = Option(ctx.ensureClause).map(visit).collect { case x: EnsureClause => x } - - if (rescueClauses.isEmpty && elseClause.isEmpty && ensureClause.isEmpty) { - visit(ctx.compoundStatement()) - } else { - RescueExpression(body, rescueClauses, elseClause, ensureClause)(ctx.toTextSpan) - } - } - - override def visitExceptionClassList(ctx: RubyParser.ExceptionClassListContext): RubyNode = { - Option(ctx.multipleRightHandSide()).map(visitMultipleRightHandSide).getOrElse(visit(ctx.operatorExpression())) - } - - override def visitRescueClause(ctx: RubyParser.RescueClauseContext): RubyNode = { - val exceptionClassList = Option(ctx.exceptionClassList).map(visit) - val variables = Option(ctx.exceptionVariableAssignment).map(visit) - val thenClause = visit(ctx.thenClause) - RescueClause(exceptionClassList, variables, thenClause)(ctx.toTextSpan) - } - - override def visitEnsureClause(ctx: RubyParser.EnsureClauseContext): RubyNode = { - EnsureClause(visit(ctx.compoundStatement()))(ctx.toTextSpan) - } - - override def visitCaseWithExpression(ctx: RubyParser.CaseWithExpressionContext): RubyNode = { - val expression = Option(ctx.expressionOrCommand()).map(visit) - val whenClauses = Option(ctx.whenClause().asScala).fold(List())(_.map(visit).toList) - val elseClause = Option(ctx.elseClause()).map(visit) - CaseExpression(expression, whenClauses, elseClause)(ctx.toTextSpan) - } - - override def visitCaseWithoutExpression(ctx: RubyParser.CaseWithoutExpressionContext): RubyNode = { - val expression = None - val whenClauses = Option(ctx.whenClause().asScala).fold(List())(_.map(visit).toList) - val elseClause = Option(ctx.elseClause()).map(visit) - CaseExpression(expression, whenClauses, elseClause)(ctx.toTextSpan) - } - - override def visitWhenClause(ctx: RubyParser.WhenClauseContext): RubyNode = { - val whenArgs = ctx.whenArgument() - val matchArgs = - Option(whenArgs.operatorExpressionList()).iterator.flatMap(_.operatorExpression().asScala).map(visit).toList - val matchSplatArg = Option(whenArgs.splattingArgument()).map(visit) - val thenClause = visit(ctx.thenClause()) - WhenClause(matchArgs, matchSplatArg, thenClause)(ctx.toTextSpan) - } - - override def visitAssociationKey(ctx: RubyParser.AssociationKeyContext): RubyNode = { - if (Option(ctx.operatorExpression()).isDefined) { - visit(ctx.operatorExpression()) - } else { - SimpleIdentifier()(ctx.toTextSpan) - } - } - - override def visitAliasStatement(ctx: RubyParser.AliasStatementContext): RubyNode = { - AliasStatement(ctx.oldName.getText, ctx.newName.getText)(ctx.toTextSpan) - } - - override def visitBreakWithoutArguments(ctx: RubyParser.BreakWithoutArgumentsContext): RubyNode = { - BreakStatement()(ctx.toTextSpan) - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/AstCreationPass.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/AstCreationPass.scala index 234a844522df..8e43529dc25e 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/AstCreationPass.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/AstCreationPass.scala @@ -1,5 +1,6 @@ package io.joern.rubysrc2cpg.passes +import flatgraph.DiffGraphApplier import io.joern.rubysrc2cpg.astcreation.AstCreator import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.NodeTypes @@ -7,7 +8,6 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewTypeDecl import io.shiftleft.passes.ForkJoinParallelCpgPass import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import org.slf4j.LoggerFactory -import overflowdb.BatchedUpdate class AstCreationPass(cpg: Cpg, astCreators: List[AstCreator]) extends ForkJoinParallelCpgPass[AstCreator](cpg) { @@ -32,7 +32,7 @@ class AstCreationPass(cpg: Cpg, astCreators: List[AstCreator]) extends ForkJoinP .astParentFullName(NamespaceTraversal.globalNamespaceName) .isExternal(true) diffGraph.addNode(emptyType).addNode(anyType) - BatchedUpdate.applyDiff(cpg.graph, diffGraph) + DiffGraphApplier.applyDiff(cpg.graph, diffGraph) } override def runOnPart(diffGraph: DiffGraphBuilder, astCreator: AstCreator): Unit = { diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ConfigFileCreationPass.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ConfigFileCreationPass.scala index 9a1f5b4987ff..07fdee3faa2b 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ConfigFileCreationPass.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ConfigFileCreationPass.scala @@ -16,5 +16,16 @@ class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg) { case None => Seq() } - override protected val configFileFilters: List[File => Boolean] = List(validGemfilePaths.contains) + override protected val configFileFilters: List[File => Boolean] = List( + // Gemfiles + validGemfilePaths.contains, + extensionFilter(".ini"), + // YAML files + extensionFilter(".yaml"), + extensionFilter(".yml"), + // XML files + extensionFilter(".xml"), + // ERB files + extensionFilter(".erb") + ) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala index c571804387f3..0e56fe443606 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala @@ -1,14 +1,20 @@ package io.joern.rubysrc2cpg.passes +import org.slf4j.LoggerFactory + object Defines { + private val logger = LoggerFactory.getLogger(getClass) + val Any: String = "ANY" + val Defined: String = "defined" val Undefined: String = "Undefined" val Object: String = "Object" val NilClass: String = "NilClass" val TrueClass: String = "TrueClass" val FalseClass: String = "FalseClass" val Numeric: String = "Numeric" + val New: String = "new" val Integer: String = "Integer" val Float: String = "Float" val String: String = "String" @@ -21,30 +27,43 @@ object Defines { val Proc: String = "proc" val Loop: String = "loop" val Self: String = "self" + val Super: String = "super" + val Rational: String = "Rational" val Initialize: String = "initialize" val TypeDeclBody: String = "" - val Program: String = ":program" + val Main: String = "
" val Resolver: String = "" - val AnonymousProcParameter = "" + def prefixAsKernelDefined(typeInString: String): String = { + if (GlobalTypes.bundledClasses.contains(typeInString)) + logger.warn(s"Type '$typeInString' is considered a 'core' type, not a 'Kernel-contained' type") + s"${GlobalTypes.kernelPrefix}.$typeInString" + } - def getBuiltInType(typeInString: String) = s"${GlobalTypes.kernelPrefix}.$typeInString" + def prefixAsCoreType(typeInString: String): String = { + if (!GlobalTypes.bundledClasses.contains(typeInString)) + logger.warn(s"Type '$typeInString' not considered a 'core' type") + s"${GlobalTypes.corePrefix}.$typeInString" + } object RubyOperators { - val hashInitializer = ".hashInitializer" - val association = ".association" - val splat = ".splat" - val regexpMatch = "=~" - val regexpNotMatch = "!~" + val backticks: String = ".backticks" + val hashInitializer = ".hashInitializer" + val association = ".association" + val splat = ".splat" + val regexpMatch = "=~" + val regexpNotMatch = "!~" + + val regexMethods = Set("match", "sub", "gsub") } } object GlobalTypes { - val Kernel = "Kernel" - val builtinPrefix = "__core" - val kernelPrefix = s"$builtinPrefix.$Kernel" + val Kernel = "Kernel" + val corePrefix = "__core" + val kernelPrefix = s"$corePrefix.$Kernel" /** Source: https://ruby-doc.org/docs/ruby-doc-bundle/Manual/man-1.4/function.html */ @@ -192,7 +211,8 @@ object GlobalTypes { /* Source: https://ruby-doc.org/3.2.2/Kernel.html * - * We comment-out methods that require an explicit "receiver" (target of member access.) + * We comment-out methods that require an explicit "receiver" (target of member access) and those that may be commonly + * shadowed. */ val kernelFunctions: Set[String] = Set( "Array", @@ -252,7 +272,7 @@ object GlobalTypes { "require", "require_all", "require_relative", - "select", +// "select", "set_trace_func", "sleep", "spawn", diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/DependencyPass.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/DependencyPass.scala index 9413191ec7c6..d7fd9648a42f 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/DependencyPass.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/DependencyPass.scala @@ -1,5 +1,6 @@ package io.joern.rubysrc2cpg.passes +import flatgraph.DiffGraphBuilder import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{ConfigFile, NewDependency} import io.shiftleft.passes.ForkJoinParallelCpgPass @@ -12,6 +13,18 @@ import io.shiftleft.semanticcpg.language.* */ class DependencyPass(cpg: Cpg) extends ForkJoinParallelCpgPass[ConfigFile](cpg) { + /** Adds all necessary initial core gems. + */ + override def init(): Unit = { + val diffGraph = Cpg.newDiffGraphBuilder + DependencyPass.CORE_GEMS + .map { coreGemName => + NewDependency().name(coreGemName).version(DependencyPass.CORE_GEM_VERSION) + } + .foreach(diffGraph.addNode) + flatgraph.DiffGraphApplier.applyDiff(cpg.graph, diffGraph) + } + /** @return * the Gemfiles, while preferring `Gemfile.lock` files if present. */ @@ -88,3 +101,107 @@ class DependencyPass(cpg: Cpg) extends ForkJoinParallelCpgPass[ConfigFile](cpg) } } + +object DependencyPass { + val CORE_GEM_VERSION: String = "3.0.0" + // Scraped from: https://ruby-doc.org/stdlib-$CORE_GEM_VERSION/ + // These gems require explicit import but no entry required in `Gemsfile` + val CORE_GEMS: Set[String] = Set( + "abbrev", + "base64", + "benchmark", + "bigdecimal", + "bundler", + "cgi", + "coverage", + "csv", + "date", + "dbm", + "debug", + "delegate", + "did_you_mean", + "digest", + "drb", + "English", + "erb", + "etc", + "extmk", + "fcntl", + "fiddle", + "fileutils", + "find", + "forwardable", + "gdbm", + "getoptlong", + "io/console", + "io/nonblock", + "io/wait", + "ipaddr", + "irb", + "json", + "logger", + "matrix", + "minitest", + "mkmf", + "monitor", + "mutex_m", + "net/ftp", + "net/http", + "net/imap", + "net/pop", + "net/protocol", + "net/smtp", + "nkf", + "objspace", + "observer", + "open-uri", + "open3", + "openssl", + "optparse", + "ostruct", + "pathname", + "power_assert", + "pp", + "prettyprint", + "prime", + "pstore", + "psych", + "pty", + "racc", + "racc/parser", + "rake", + "rbs", + "readline", + "readline", + "reline", + "resolv", + "resolv-replace", + "rexml", + "rinda", + "ripper", + "rss", + "rubygems", + "securerandom", + "set", + "shellwords", + "singleton", + "socket", + "stringio", + "strscan", + "syslog", + "tempfile", + "test-unit", + "time", + "timeout", + "tmpdir", + "tracer", + "tsort", + "typeprof", + "un", + "uri", + "weakref", + "win32ole", + "yaml", + "zlib" + ) +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/DependencySummarySolverPass.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/DependencySummarySolverPass.scala index 30f4908bbad0..f5bb6e11e890 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/DependencySummarySolverPass.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/DependencySummarySolverPass.scala @@ -16,7 +16,7 @@ class DependencySummarySolverPass(cpg: Cpg, dependencySummary: RubyProgramSummar override def runOnPart(diffGraph: DiffGraphBuilder, dependency: Dependency): Unit = { dependencySummary.namespaceToType.filter(_._1.startsWith(dependency.name)).flatMap(_._2).foreach { x => val typeDeclName = - if x.name.endsWith(RDefines.Program) then RDefines.Program + if x.name.endsWith(RDefines.Main) then RDefines.Main else x.name.split("[.]").lastOption.getOrElse(Defines.Unknown) val dependencyTypeDecl = TypeDeclStubCreator diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ImplicitRequirePass.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ImplicitRequirePass.scala deleted file mode 100644 index f556b25457fb..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ImplicitRequirePass.scala +++ /dev/null @@ -1,103 +0,0 @@ -package io.joern.rubysrc2cpg.passes - -import io.joern.rubysrc2cpg.datastructures.{RubyProgramSummary, RubyType} -import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{Cpg, DispatchTypes, EdgeTypes, Operators} -import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language.* -import org.apache.commons.text.CaseUtils - -import scala.collection.mutable - -/** In some Ruby frameworks, it is common to have an autoloader library that implicitly loads requirements onto the - * stack. This pass makes these imports explicit. The most popular one is Zeitwerk which we check in `Gemsfile.lock` to enable this pass. - */ -class ImplicitRequirePass(cpg: Cpg, programSummary: RubyProgramSummary) extends ForkJoinParallelCpgPass[Method](cpg) { - - private val importCallName: String = "require" - private val typeToPath = mutable.Map.empty[String, String] - - override def init(): Unit = { - programSummary.pathToType - .map { case (path, types) => - // zeitwerk will match types that share the name of the path. - // This match is insensitive to camel case, i.e, foo_bar will match type FooBar. - val fileName = path.split('/').last - path -> types.filter { t => - val typeName = t.name.split("[.]").last - typeName == fileName || typeName == CaseUtils.toCamelCase(fileName, true, '_', '-') - } - } - .foreach { case (path, types) => - types.foreach { typ => typeToPath.put(typ.name, path) } - } - } - - override def generateParts(): Array[Method] = - cpg.method.isModule.whereNot(_.astChildren.isCall.nameExact(importCallName)).toArray - - /** Collects methods within a module. - */ - private def findMethodsViaAstChildren(module: Method): Iterator[Method] = { - Iterator(module) ++ module.astChildren.flatMap { - case x: TypeDecl => x.method.flatMap(findMethodsViaAstChildren) - case x: Method => Iterator(x) ++ x.astChildren.collectAll[Method].flatMap(findMethodsViaAstChildren) - case _ => Iterator.empty - } - } - - override def runOnPart(builder: DiffGraphBuilder, part: Method): Unit = { - findMethodsViaAstChildren(part).ast.isCall - .flatMap { - case x if x.name == Operators.alloc => - x.argument.isIdentifier - case x => - x.receiver.fieldAccess.fieldIdentifier - } - .map { - case fi: FieldIdentifier => fi -> programSummary.matchingTypes(fi.canonicalName) - case i: Identifier => i -> programSummary.matchingTypes(i.name) - } - .distinct - .foreach { case (identifier, rubyTypes) => - val requireCalls = rubyTypes.flatMap { rubyType => - typeToPath.get(rubyType.name) match { - case Some(path) - if identifier.file.name - .map(_.replace("\\", "/")) - .headOption - .exists(x => rubyType.name.startsWith(x)) => - None // do not add an import to a file that defines the type - case Some(path) => Option(createRequireCall(builder, rubyType, path)) - case None => None - } - } - val startIndex = part.block.astChildren.size - requireCalls.zipWithIndex.foreach { case (call, idx) => - call.order(startIndex + idx) - builder.addEdge(part.block, call, EdgeTypes.AST) - } - } - } - - private def createRequireCall(builder: DiffGraphBuilder, rubyType: RubyType, path: String): NewCall = { - val requireCallNode = NewCall() - .name(importCallName) - .code(s"$importCallName '$path'") - .methodFullName(s"__builtin:$importCallName") - .dispatchType(DispatchTypes.DYNAMIC_DISPATCH) - .typeFullName(Defines.Any) - val receiverIdentifier = - NewIdentifier().name(importCallName).code(importCallName).typeFullName(Defines.Any).argumentIndex(0).order(1) - val pathLiteralNode = NewLiteral().code(s"'$path'").typeFullName("__builtin.String").argumentIndex(1).order(2) - builder.addNode(requireCallNode) - builder.addEdge(requireCallNode, receiverIdentifier, EdgeTypes.AST) - builder.addEdge(requireCallNode, receiverIdentifier, EdgeTypes.ARGUMENT) - builder.addEdge(requireCallNode, receiverIdentifier, EdgeTypes.RECEIVER) - builder.addEdge(requireCallNode, pathLiteralNode, EdgeTypes.AST) - builder.addEdge(requireCallNode, pathLiteralNode, EdgeTypes.ARGUMENT) - requireCallNode - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/utils/DependencyDownloader.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/utils/DependencyDownloader.scala index 0b6497bad368..958700b71089 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/utils/DependencyDownloader.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/utils/DependencyDownloader.scala @@ -2,7 +2,8 @@ package io.joern.rubysrc2cpg.utils import better.files.File import io.joern.rubysrc2cpg.datastructures.RubyProgramSummary -import io.joern.rubysrc2cpg.passes.Defines +import io.joern.rubysrc2cpg.parser.RubyAstGenRunner +import io.joern.rubysrc2cpg.passes.{Defines, DependencyPass} import io.joern.rubysrc2cpg.{Config, RubySrc2Cpg, parser} import io.joern.x2cpg.utils.ConcurrentTaskUtil import io.shiftleft.codepropertygraph.generated.Cpg @@ -34,10 +35,15 @@ class DependencyDownloader(cpg: Cpg) { */ def download(): RubyProgramSummary = { File.temporaryDirectory("joern-rubysrc2cpg").apply { dir => - cpg.dependency.filterNot(_.name == Defines.Resolver).foreach { dependency => - Try(Thread.sleep(100)) // Rate limit - downloadDependency(dir, dependency) - } + cpg.dependency + .filterNot(dep => + dep.name == Defines.Resolver || + (DependencyPass.CORE_GEMS.contains(dep.name) && DependencyPass.CORE_GEM_VERSION == dep.version) + ) + .foreach { dependency => + Try(Thread.sleep(100)) // Rate limit + downloadDependency(dir, dependency) + } untarDependencies(dir) summarizeDependencies(dir / "lib") } @@ -204,16 +210,24 @@ class DependencyDownloader(cpg: Cpg) { RubyProgramSummary(libSummary.namespaceToType, pathMappings) } - Using.resource(new parser.ResourceManagedParser(0.8)) { parser => + val tmpDir = File.newTemporaryDirectory("rubysrc2cpgOut") + try { + val config = Config().withDisableFileContent(true) + val astGenResult = RubyAstGenRunner(Config().withInputPath(targetDir.toString)).execute(tmpDir) + val astCreators = ConcurrentTaskUtil .runUsingThreadPool( - RubySrc2Cpg - .generateParserTasks(parser, Config().withInputPath(targetDir.pathAsString), Option(targetDir.pathAsString)) + RubySrc2Cpg.processAstGenRunnerResults(astGenResult.parsedFiles, config, Option(targetDir.toString)) ) .flatMap { - case Failure(exception) => logger.warn(s"Could not parse file, skipping - ", exception); None + case Failure(exception) => logger.warn(s"Unable to parse Ruby file, skipping -", exception); None case Success(astCreator) => Option(astCreator) } + .filter(x => { + if x.fileContent.isBlank then logger.info(s"File content empty, skipping - ${x.fileName}") + !x.fileContent.isBlank + }) + // Pre-parse the AST creators for high level structures val librarySummaries = ConcurrentTaskUtil .runUsingThreadPool(astCreators.map(x => () => remapPaths(x.summarize(asExternal = true))).iterator) @@ -225,6 +239,8 @@ class DependencyDownloader(cpg: Cpg) { .getOrElse(RubyProgramSummary()) librarySummaries + } finally { + tmpDir.delete(swallowIOExceptions = true) } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/utils/FreshNameGenerator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/utils/FreshNameGenerator.scala index a7e77e248d8a..fc87a4a4c89e 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/utils/FreshNameGenerator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/utils/FreshNameGenerator.scala @@ -7,4 +7,8 @@ class FreshNameGenerator[T](template: Int => T) { counter += 1 name } + + def current: T = { + template(counter - 1) + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/config/ConfigTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/config/ConfigTests.scala index e327e5a5e14f..3346fd6607ae 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/config/ConfigTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/config/ConfigTests.scala @@ -18,7 +18,9 @@ class ConfigTests extends AnyWordSpec with Matchers with Inside { "--output", "OUTPUT", "--exclude", - "1EXCLUDE_FILE,2EXCLUDE_FILE", + "1EXCLUDE_FILE", + "--exclude", + "2EXCLUDE_FILE", "--exclude-regex", "EXCLUDE_REGEX" // Frontend-specific args diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/config/IgnoreRegexTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/config/IgnoreRegexTest.scala new file mode 100644 index 000000000000..c74f206aeedf --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/config/IgnoreRegexTest.scala @@ -0,0 +1,17 @@ +package io.joern.rubysrc2cpg.config + +import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.shiftleft.semanticcpg.language.* + +class IgnoreRegexTest extends RubyCode2CpgFixture { + "File matching default ignore regex should be skipped" in { + val cpg = code( + """ + |puts "test file" + |""".stripMargin, + "tmpdir/db/migrate/test0.rb" + ) + + cpg.file.map(_.name).contains("tmpdir/db/migrate/test0.rb") shouldBe false + } +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/ArrayTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/ArrayTests.scala index 71a6d0b70ba2..20ac510988ac 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/ArrayTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/ArrayTests.scala @@ -120,7 +120,7 @@ class ArrayTests extends RubyCode2CpgFixture(withPostProcessing = true, withData val source = cpg.literal.code("b").l val sink = cpg.call.name("puts").argument(1).l val List(flow) = sink.reachableByFlows(source).map(flowToResultPairs).distinct.sortBy(_.length).l - flow shouldBe List(("%w[b c]", 2), ("a = %w[b c]", 2), ("puts a", 3)) + flow shouldBe List(("[0] = b", 2), ("", 2), ("a = %w[b c]", 2), ("puts a", 3)) } "flow through %i array" in { @@ -130,15 +130,12 @@ class ArrayTests extends RubyCode2CpgFixture(withPostProcessing = true, withData |puts a |""".stripMargin) - val source = cpg.literal.code("b").l + val source = cpg.literal.code(".*b.*").l val sink = cpg.call.name("puts").argument(1).l val List(flow) = sink.reachableByFlows(source).map(flowToResultPairs).distinct.sortBy(_.length).l flow shouldBe List( - ( - """|%i[b - | c]""".stripMargin, - 2 - ), + ("[0] = :b", 2), + ("", 2), ( """|a = %i[b | c]""".stripMargin, diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/CallTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/CallTests.scala index d54cff93bd70..df3264db6379 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/CallTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/CallTests.scala @@ -287,8 +287,7 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true, withDataF | |x = 1 |foo = Foo.new - |y = foo - | .bar(1) + |y = foo.bar(1) |puts y |""".stripMargin) @@ -296,25 +295,13 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true, withDataF val sink = cpg.call.name("puts").argument(1).l val List(flow) = sink.reachableByFlows(src).map(flowToResultPairs).distinct.sortBy(_.length).l flow shouldBe List( - ( - """|foo - | .bar(1)""".stripMargin, - 11 - ), + ("foo.bar(1)", 10), ("bar(self, x)", 3), ("return x", 4), ("RET", 3), - ( - """|foo - | .bar(1)""".stripMargin, - 10 - ), - ( - """|y = foo - | .bar(1)""".stripMargin, - 10 - ), - ("puts y", 12) + ("foo.bar(1)", 10), + ("y = foo.bar(1)", 10), + ("puts y", 11) ) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/MethodTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/MethodTests.scala index bcdd6e343564..cb2aacc067c9 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/MethodTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/MethodTests.scala @@ -5,8 +5,8 @@ import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.semanticcpg.language.* class MethodTests extends RubyCode2CpgFixture(withPostProcessing = true, withDataFlow = true) { - // Works in deprecated - "Data flow through class method" ignore { + + "Data flow through class method" in { val cpg = code(""" |class MyClass | def print(text) @@ -21,8 +21,8 @@ class MethodTests extends RubyCode2CpgFixture(withPostProcessing = true, withDat |""".stripMargin) val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 + val sink = cpg.call.name("puts").argument.l + sink.reachableByFlows(src).size shouldBe 4 } "Data flow through do-while loop" in { @@ -57,7 +57,7 @@ class MethodTests extends RubyCode2CpgFixture(withPostProcessing = true, withDat } // Works in deprecated - "Data flow through blockExprAssocTypeArguments" ignore { + "Data flow through blockExprAssocTypeArguments" in { val cpg = code(""" |def foo(*args) |puts args @@ -69,12 +69,11 @@ class MethodTests extends RubyCode2CpgFixture(withPostProcessing = true, withDat |""".stripMargin) val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 + val sink = cpg.call.name("puts").argument.l + sink.reachableByFlows(source).size shouldBe 4 } - // Works in deprecated - Unsupported element type SplattingArgumentArgumentList - "Data flow through blockSplattingTypeArguments" ignore { + "Data flow through blockSplattingTypeArguments" in { val cpg = code(""" |def foo(arg) |puts arg @@ -86,12 +85,11 @@ class MethodTests extends RubyCode2CpgFixture(withPostProcessing = true, withDat |""".stripMargin) val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 + val sink = cpg.call.name("puts").argument.l + sink.reachableByFlows(source).size shouldBe 4 } - // Works in deprecated - "Data flow through blockSplattingExprAssocTypeArguments without block" ignore { + "Data flow through blockSplattingExprAssocTypeArguments without block" in { val cpg = code(""" |def foo(*arg) |puts arg @@ -103,12 +101,11 @@ class MethodTests extends RubyCode2CpgFixture(withPostProcessing = true, withDat |""".stripMargin) val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 + val sink = cpg.call.name("puts").argument.l + sink.reachableByFlows(source).size shouldBe 4 } - // Works in deprecated - Unsupported element type SplattingArgumentArgumentList - "Data flow through blockSplattingTypeArguments without block" ignore { + "Data flow through blockSplattingTypeArguments without block" in { val cpg = code(""" |def foo (blockArg,&block) |block.call(blockArg) @@ -124,7 +121,7 @@ class MethodTests extends RubyCode2CpgFixture(withPostProcessing = true, withDat |""".stripMargin) val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l + val sink = cpg.call.name("puts").argument.l sink.reachableByFlows(source).size shouldBe 2 } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/ProcParameterAndYieldTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/ProcParameterAndYieldTests.scala index 04ca9426a463..3aeece7aed74 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/ProcParameterAndYieldTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/ProcParameterAndYieldTests.scala @@ -95,7 +95,7 @@ class ProcParameterAndYieldTests extends RubyCode2CpgFixture(withPostProcessing sink2.reachableByFlows(src2).size shouldBe 2 } - "Data flow through invocationWithBlockOnlyPrimary usage" in { + "Data flow through invocationWithBlockOnlyPrimary usage" ignore { val cpg = code(""" |def hello(&block) | block.call @@ -110,7 +110,7 @@ class ProcParameterAndYieldTests extends RubyCode2CpgFixture(withPostProcessing sink.reachableByFlows(source).size shouldBe 1 } - "Data flow through invocationWithBlockOnlyPrimary and method name starting with capital usage" in { + "Data flow through invocationWithBlockOnlyPrimary and method name starting with capital usage" ignore { val cpg = code(""" |def Hello(&block) | block.call @@ -126,7 +126,7 @@ class ProcParameterAndYieldTests extends RubyCode2CpgFixture(withPostProcessing } // Works in deprecated - "Data flow for yield block specified along with the call" in { + "Data flow for yield block specified along with the call" ignore { val cpg = code(""" |x=10 |def foo(x) @@ -168,7 +168,7 @@ class ProcParameterAndYieldTests extends RubyCode2CpgFixture(withPostProcessing sink.reachableByFlows(source).size shouldBe 2 } - "flow through a proc definition with non-empty block and zero parameters" in { + "flow through a proc definition with non-empty block and zero parameters" ignore { val cpg = code(""" |x=10 |y = x diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/SingleAssignmentTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/SingleAssignmentTests.scala index 7f20321567c8..8ee26a108469 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/SingleAssignmentTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/SingleAssignmentTests.scala @@ -15,13 +15,13 @@ class SingleAssignmentTests extends RubyCode2CpgFixture(withPostProcessing = tru |""".stripMargin) val source = cpg.literal.l val sink = cpg.method.name("puts").callIn.argument.l - val flows = sink.reachableByFlows(source).map(flowToResultPairs).distinct.sortBy(_.length).l - val List(flow1, flow2, flow3, flow4, flow5) = flows - flow1 shouldBe List(("y = 1", 2), ("puts y", 3)) - flow2 shouldBe List(("y = 1", 2), ("x = y = 1", 2), ("puts x", 4)) - flow3 shouldBe List(("y = 1", 2), ("puts y", 3), ("puts x", 4)) - flow4 shouldBe List(("y = 1", 2), ("x = y = 1", 2), ("z = x = y = 1", 2), ("puts z", 5)) - flow5 shouldBe List(("y = 1", 2), ("x = y = 1", 2), ("puts x", 4), ("puts z", 5)) + val flows = sink.reachableByFlows(source).map(flowToResultPairs).distinct.l + flows.size shouldBe 5 + flows should contain(List(("y = 1", 2), ("puts y", 3))) + flows should contain(List(("y = 1", 2), ("x = y = 1", 2), ("puts x", 4))) + flows should contain(List(("y = 1", 2), ("puts y", 3), ("puts x", 4))) + flows should contain(List(("y = 1", 2), ("x = y = 1", 2), ("z = x = y = 1", 2), ("puts z", 5))) + flows should contain(List(("y = 1", 2), ("x = y = 1", 2), ("puts x", 4), ("puts z", 5))) } "flow through expressions" in { @@ -47,6 +47,21 @@ class SingleAssignmentTests extends RubyCode2CpgFixture(withPostProcessing = tru sink.reachableByFlows(src).l.size shouldBe 2 } + "flow through **=" in { + val cpg = code(""" + |x = 5 + |call1(x**=2) + |call2(x) + |""".stripMargin) + + val source = cpg.literal("2").l + val call1 = cpg.call("call1") + val call2 = cpg.call("call2") + + call1.reachableBy(source).l shouldBe source + call2.reachableBy(source).l shouldBe source + } + "Data flow through grouping expression" in { val cpg = code(""" |x = 0 diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/dataflow/DataFlowTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/dataflow/DataFlowTests.scala deleted file mode 100644 index d95eaf938596..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/dataflow/DataFlowTests.scala +++ /dev/null @@ -1,2633 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.dataflow - -import io.joern.dataflowengineoss.language.* -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language.* - -class DataFlowTests - extends RubyCode2CpgFixture(withPostProcessing = true, withDataFlow = true, useDeprecatedFrontend = true) { - - "Data flow through if-elseif-else" should { - val cpg = code(""" - |x = 2 - |a = x - |b = 0 - | - |if a > 2 - | b = a + 3 - |elsif a > 4 - | b = a + 5 - |elsif a > 8 - | b = a + 5 - |else - | b = a + 9 - |end - | - |puts(b) - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Flow via call" should { - val cpg = code(""" - |def print(content) - |puts content - |end - | - |def main - |n = 1 - |print( n ) - |end - |""".stripMargin) - - "be found" in { - implicit val resolver: ICallResolver = NoResolve - val src = cpg.identifier.name("n").where(_.inCall.name("print")).l - val sink = cpg.method.name("puts").callIn.argument(1).l - sink.reachableByFlows(src).size shouldBe 1 - } - } - - "Explicit return via call with initialization" should { - val cpg = code(""" - |def add(p) - |q = 5 - |q = p - |return q - |end - | - |n = 1 - |ret = add(n) - |puts ret - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("n").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Implicit return via call with initialization" should { - val cpg = code(""" - |def add(p) - |q = 5 - |q = p - |q - |end - | - |n = 1 - |ret = add(n) - |puts ret - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("n").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Implicit return in if-else block" should { - val cpg = code(""" - |def foo(arg) - |if arg > 1 - | arg + 1 - |else - | arg + 10 - |end - |end - | - |x = 1 - |y = foo x - |puts y - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Implicit return in if-else block and underlying function call" should { - val cpg = code(""" - |def add(arg) - |arg + 100 - |end - | - |def foo(arg) - |if arg > 1 - | add(arg) - |else - | add(arg) - |end - |end - | - |x = 1 - |y = foo x - |puts y - | - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Return via call w/o initialization" should { - val cpg = code(""" - |def add(p) - |q = p - |return q - |end - | - |n = 1 - |ret = add(n) - |puts ret - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("n").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Data flow in a while loop" should { - val cpg = code(""" - |i = 0 - |num = 5 - | - |while i < num do - | num = i + 3 - |end - |puts num - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("i").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 3 - } - } - - "Data flow in a while modifier" should { - val cpg = code(""" - |i = 0 - |num = 5 - |begin - | num = i + 3 - |end while i < num - |puts num - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("i").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 3 - } - } - - "Data flow through expressions" should { - val cpg = code(""" - |a = 1 - |b = a+3 - |c = 2 + b%6 - |d = c + b & !c + -b - |e = c/d + b || d - |f = c - d & ~e - |g = f-c%d - +d - |h = g**5 << b*g - |i = b && c || e > g - |j = b>c ? (e+-6) : (f +5) - |k = i..h - |l = j...g - | - |puts l - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("a").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Data flow through multiple assignments" should { - val cpg = code(""" - |x = 1 - |y = 2 - |c, d = x, y - |puts c - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Data flow through multiple assignments with grouping" should { - val cpg = code(""" - |x = 1 - |y = 2 - |(c, d) = x, y - |puts c - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Data flow through multiple assignments with multi level grouping" ignore { - val cpg = code(""" - |x = 1 - |y = 2 - |z = 3 - |a,(b,c) = z,y,x - |puts a - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - "Data flow through multiple assignments with grouping and method in RHS" should { - val cpg = code(""" - |def foo() - |x = 1 - |return x - |end - | - |b = 2 - |(c, d) = foo, b - |puts c - | - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Data flow through single LHS and splatting RHS" should { - val cpg = code(""" - |x=1 - |y=*x - |puts y - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Data flow through class method" should { - val cpg = code(""" - |class MyClass - | def print(text) - | puts text - | end - |end - | - | - |x = "some text" - |inst = MyClass.new - |inst.print(x) - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Data flow through class member" should { - val cpg = code(""" - |class MyClass - | @instanceVariable - | - | def initialize(value) - | @instanceVariable = value - | end - | - | def getValue() - | @instanceVariable - | end - |end - | - |x = 12345 - |inst = MyClass.new(x) - |y = inst.getValue - |puts y - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Data flow through module method" should { - val cpg = code(""" - |module MyModule - | def MyModule.print(text) - | puts text - | end - |end - | - |x = "some text" - | - |MyModule::print(x) - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 2 - } - } - - "Data flow through yield with argument having parenthesis" should { - val cpg = code(""" - |def yield_with_arguments - | a = "something" - | yield(a) - |end - | - |yield_with_arguments { |arg| puts "Argument is #{arg}" } - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("a").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).size shouldBe 2 - } - } - - "Data flow through yield with argument without parenthesis and multiple yield blocks" should { - val cpg = code(""" - |def yield_with_arguments - | x = "something" - | y = "something_else" - | yield(x,y) - |end - | - |yield_with_arguments { |arg1, arg2| puts "Yield block 1 #{arg1} and #{arg2}" } - |yield_with_arguments { |arg1, arg2| puts "Yield block 2 #{arg2} and #{arg1}" } - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).size shouldBe 4 - } - } - - "Data flow through yield without argument" should { - val cpg = code(""" - |x = 1 - |def yield_method - | yield - |end - |yield_method { puts x } - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).size shouldBe 1 - } - } - - "Data flow coming out of yield without argument" should { - val cpg = code(""" - |def foo - | x=10 - | z = yield - | puts z - |end - | - |x = 100 - |foo{ x + 10 } - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).size shouldBe 1 - } - } - // TODO: - "Data flow through yield with argument and multiple yield blocks" ignore { - val cpg = code(""" - |def yield_with_arguments - | x = "something" - | y = "something_else" - | yield(x) - | yield(y) - |end - | - |yield_with_arguments { |arg| puts "Yield block 1 #{arg}" } - |yield_with_arguments { |arg| puts "Yield block 2 #{arg}" } - |""".stripMargin) - - "be found" in { - val src1 = cpg.identifier.name("x").l - val sink1 = cpg.call.name("puts").l - sink1.reachableByFlows(src1).size shouldBe 2 - - val src2 = cpg.identifier.name("y").l - val sink2 = cpg.call.name("puts").l - sink2.reachableByFlows(src2).size shouldBe 2 - } - } - - "Data flow through a until loop" should { - val cpg = code(""" - |i = 0 - |num = 5 - | - |until i < num - | num = i + 3 - |end - |puts num - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("i").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 3 - } - } - - "Data flow in through until modifier" should { - val cpg = code(""" - |i = 0 - |num = 5 - |begin - | num = i + 3 - |end until i < num - |puts num - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("i").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).l.size shouldBe 3 - } - } - - "Data flow through unless-else" should { - val cpg = code(""" - |x = 2 - |a = x - |b = 0 - | - |unless a > 2 - | b = a + 3 - |else - | b = a + 9 - |end - | - |puts(b) - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through case statement" should { - val cpg = code(""" - |x = 2 - |b = x - | - |case b - |when 1 - | puts b - |when 2 - | puts b - |when 3 - | puts b - |else - | puts b - |end - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 8 - } - } - - "Data flow through do-while loop" should { - val cpg = code(""" - |x = 0 - |num = -1 - |loop do - | num = x + 1 - | x = x + 1 - | if x > 10 - | break - | end - |end - |puts num - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through for loop" should { - val cpg = code(""" - |x = 0 - |arr = [1,2,3,4,5] - |num = 0 - |for i in arr do - | y = x + i - | num = y*i - |end - |puts num - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through for loop simple" should { - val cpg = code(""" - |x = 0 - |arr = [1,2,3,4,5] - |num = 0 - |for i in arr do - | num = x - |end - |puts num - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through for and next AFTER statement" should { - val cpg = code(""" - |x = 0 - |arr = [1,2,3,4,5] - |num = 0 - |for i in arr do - | num = x - | next if i % 2 == 0 - |end - |puts num - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through for and next BEFORE statement" should { - val cpg = code(""" - |x = 0 - |arr = [1,2,3,4,5] - |num = 0 - |for i in arr do - | next if i % 2 == 0 - | num = x - |end - |puts num - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through for and redo AFTER statement" should { - val cpg = code(""" - |x = 0 - |arr = [1,2,3,4,5] - |num = 0 - |for i in arr do - | num = x - | redo if i % 2 == 0 - |end - |puts num - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through for and redo BEFORE statement" should { - val cpg = code(""" - |x = 0 - |arr = [1,2,3,4,5] - |num = 0 - |for i in arr do - | redo if i % 2 == 0 - | num = x - |end - |puts num - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through for and retry AFTER statement" should { - val cpg = code(""" - |x = 0 - |arr = [1,2,3,4,5] - |num = 0 - |for i in arr do - | num = x - | retry if i % 2 == 0 - |end - |puts num - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through for and retry BEFORE statement" should { - val cpg = code(""" - |x = 0 - |arr = [1,2,3,4,5] - |num = 0 - |for i in arr do - | retry if i % 2 == 0 - | num = x - |end - |puts num - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through grouping expression" should { - val cpg = code(""" - |x = 0 - |y = (x==0) - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through variable assigned a scoped constant" should { - val cpg = code(""" - |MyConst = 10 - |x = ::MyConst - |puts x - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through variable assigned a chained scoped constant" should { - val cpg = code(""" - |MyConst = 10 - |x = ::MyConst - |puts x - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through array constructor expressionsOnlyIndexingArguments" should { - val cpg = code(""" - |x = 1 - |array = [x,2] - |puts x - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 3 - } - } - - "Data flow through array constructor splattingOnlyIndexingArguments" should { - val cpg = code(""" - |def foo(*splat_args) - |array = [*splat_args] - |puts array - |end - | - |x = 1 - |y = 2 - |y = foo(x,y) - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through array constructor expressionsAndSplattingIndexingArguments" should { - val cpg = code(""" - |def foo(*splat_args) - |array = [1,2,*splat_args] - |puts array - |end - | - |x = 3 - |foo(x) - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through array constructor associationsOnlyIndexingArguments" should { - val cpg = code(""" - |def foo(arg) - |array = [1 => arg, 2 => arg] - |puts array - |end - | - |x = 3 - |foo(x) - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through array constructor commandOnlyIndexingArguments" should { - val cpg = code(""" - |def increment(arg) - |return arg + 1 - |end - | - |x = 1 - |array = [ increment(x), increment(x+1)] - |puts array - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 3 - } - } - - "Data flow through hash constructor" should { - val cpg = code(""" - |def foo(arg) - |hash = {1 => arg, 2 => arg} - |puts hash - |end - | - |x = 3 - |foo(x) - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through string interpolation" should { - val cpg = code(""" - |x = 1 - |str = "The source is #{x}" - |puts str - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through indexingExpressionPrimary" should { - val cpg = code(""" - |x = [1,2,3] - |y = x[0] - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through methodOnlyIdentifier usage" should { - val cpg = code(""" - |x = 1 - |y = SomeConstant! + x - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through chainedInvocationPrimary usage" should { - val cpg = code(""" - |x = 1 - | - |[x, x+1].each do |number| - | puts "#{number} was passed to the block" - |end - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - // TODO: - "Data flow coming out of chainedInvocationPrimary usage" ignore { - val cpg = code(""" - |x = 1 - |y = 10 - |[x, x+1].each do |number| - | y += x - |end - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through chainedInvocationPrimary without arguments to block usage" should { - val cpg = code(""" - |x = 1 - | - |[1,2,3].each do - | puts "Right here #{x}" - |end - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 1 - } - } - - "Data flow through invocationWithBlockOnlyPrimary usage" should { - val cpg = code(""" - |def hello(&block) - | block.call - |end - | - |x = "hello" - |hello { puts x } - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow through invocationWithBlockOnlyPrimary and method name starting with capital usage" should { - val cpg = code(""" - |def Hello(&block) - | block.call - |end - |x = "hello" - |Hello = "this should not be used" - |Hello { puts x } - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).l.size shouldBe 2 - } - } - - "Data flow for begin/rescue with sink in begin" should { - val cpg = code(""" - |x = 1 - |begin - | puts x - |rescue SomeException - | puts "SomeException occurred" - |rescue => exceptionVar - | puts "Caught exception in variable #{exceptionVar}" - |rescue - | puts "Catch-all block" - |end - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for begin/rescue with sink in else" should { - val cpg = code(""" - |x = 1 - |begin - | puts "In begin" - |rescue SomeException - | puts "SomeException occurred" - |rescue => exceptionVar - | puts "Caught exception in variable #{exceptionVar}" - |rescue - | puts "Catch-all block" - |else - | puts x - |end - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for begin/rescue with sink in rescue" should { - val cpg = code(""" - |x = 1 - |begin - | puts "in begin" - |rescue SomeException - | puts x - |rescue => exceptionVar - | puts "Caught exception in variable #{exceptionVar}" - |rescue - | puts "Catch-all block" - |end - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for begin/rescue with sink in rescue with exception var" should { - val cpg = code(""" - |begin - | puts "in begin" - |rescue SomeException - | puts "SomeException occurred" - |rescue => x - | y = x - | puts "Caught exception in variable #{y}" - |rescue - | puts "Catch-all block" - |end - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 1 - } - } - - "Data flow for begin/rescue with sink in catch-all rescue" should { - val cpg = code(""" - |x = 1 - |begin - | puts "in begin" - |rescue SomeException - | puts "SomeException occurred" - |rescue => exceptionVar - | puts "Caught exception in variable #{exceptionVar}" - |rescue - | puts x - |end - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for begin/rescue with sink in ensure" should { - val cpg = code(""" - |x = 1 - |begin - | puts "in begin" - |rescue SomeException - | puts "SomeException occurred" - |rescue => exceptionVar - | puts "Caught exception in variable #{exceptionVar}" - |rescue - | puts "In rescue all" - |ensure - | puts x - |end - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for begin/rescue with data flow through the exception" should { - val cpg = code(""" - |x = "Exception message: " - |begin - |1/0 - |rescue ZeroDivisionError => e - | y = x + e.message - | puts y - |end - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for begin/rescue with data flow through block with multiple exceptions being caught" should { - val cpg = code(""" - |x = 1 - |y = 10 - |begin - |1/0 - |rescue SystemCallError, ZeroDivisionError - | y = x + 100 - |end - | - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for begin/rescue with sink in function without begin" should { - val cpg = code(""" - |def foo(arg) - | puts "in begin" - |rescue SomeException - | return arg - |rescue => exvar - | puts "Caught exception in variable #{exvar}" - |rescue - | puts "Catch-all block" - |end - | - |x = 1 - |y = foo x - |puts y - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for begin/rescue with sink in function without begin and sink in rescue with exception" should { - val cpg = code(""" - |def foo(arg) - | puts "in begin" - |rescue SomeException - | puts "SomeException occurred #{arg}" - |rescue => exvar - | puts "Caught exception in variable #{exvar}" - |rescue - | puts "Catch-all block" - |end - | - |x = 1 - |foo x - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for begin/rescue with sink in function without begin and sink in catch-call rescue" should { - val cpg = code(""" - |def foo(arg) - | puts "in begin" - | raise "This is an exception" - |rescue - | puts "Catch-all block. Arg is #{arg}" - |end - | - |x = 1 - |foo x - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for begin/rescue with sink in function without begin with return from begin" should { - val cpg = code(""" - |def foo(arg) - | puts "in begin" - | return arg - |rescue SomeException - | puts "Caught SomeException" - |rescue => exvar - | puts "Caught exception in variable #{exvar}" - |rescue - | puts "Catch-all block" - |end - | - |x = 1 - |y = foo x - |puts y - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for begin/rescue with sink in function within do block" ignore { - val cpg = code(""" - |def foo(arg) - | puts "in begin" - | arg do |y| - | return y - |rescue SomeException - | puts "Caught SomeException" - |rescue => exvar - | puts "Caught exception in variable #{exvar}" - |rescue - | puts "Catch-all block" - |end - | - |x = 1 - |z = foo x - |puts z - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").lineNumber(8).l - sink.reachableByFlows(source).size shouldBe 1 - } - } - - "Data flow through array assignments" should { - val cpg = code(""" - |x = 10 - |array = [0, 1] - |array[0] = x - |puts array - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through chained scoped constant reference" should { - val cpg = code(""" - |module SomeModule - |SomeConstant = 1 - |end - | - |x = 1 - |y = SomeModule::SomeConstant * x - |puts y - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through scopedConstantAccessSingleLeftHandSide" should { - val cpg = code(""" - |SomeConstant = 1 - | - |x = 1 - |::SomeConstant = x - |y = ::SomeConstant + 10 - |puts y - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through xdotySingleLeftHandSide through a constant on left of the ::" should { - val cpg = code(""" - |module SomeModule - |SomeConstant = 100 - |end - | - |x = 2 - |SomeModule::SomeConstant = x - |y = SomeModule::SomeConstant - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - // TODO: - "Data flow through xdotySingleLeftHandSide through a local on left of the ::" ignore { - val cpg = code(""" - |module SomeModule - |SomeConstant = 100 - |end - | - |x = 2 - |local = SomeModule - |local::SomeConstant = x - |y = SomeModule::SomeConstant - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through packing left hand side through the first identifier" should { - val cpg = code(""" - |x = 1 - |p = 2 - |*y = x,p - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through packing left hand side through beyond the first identifier" should { - val cpg = code(""" - |x = 1 - |y = 2 - |z = 3 - |*a = z,y,x - |puts a - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through packing left hand side with unequal RHS" should { - val cpg = code(""" - |x = 1 - |y = 2 - |z = 3 - |p,*a = z,y,x - |puts a - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through single LHS and multiple RHS" should { - val cpg = code(""" - |x = 1 - |y = 2 - |z = 3 - |a = z,y,x - |puts a - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through argsAndDoBlockAndMethodIdCommandWithDoBlock" should { - val cpg = code(""" - |def foo (blockArg,&block) - |block.call(blockArg) - |end - | - |x = 10 - |foo :a_symbol do |arg| - | y = x + arg.length - | puts y - |end - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through primaryMethodArgsDoBlockCommandWithDoBlock" should { - val cpg = code(""" - |module FooModule - |def foo (blockArg,&block) - |block.call(blockArg) - |end - |end - | - |x = 10 - |FooModule.foo :a_symbol do |arg| - | y = x + arg.length - | puts y - |end - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow with super usage" should { - val cpg = code(""" - |class BaseClass - | def doSomething(arg) - | return arg + 10 - | end - |end - | - |class DerivedClass < BaseClass - | def doSomething(arg) - | super(arg) - | end - |end - | - |x = 1 - |object = DerivedClass.new - |y = object.doSomething(x) - |puts y - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through blockExprAssocTypeArguments" should { - val cpg = code(""" - |def foo(*args) - |puts args - |end - | - |x = "value1" - |foo(key1: x, key2: "value2") - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through blockSplattingTypeArguments" should { - val cpg = code(""" - |def foo(arg) - |puts arg - |end - | - |x = 1 - |foo(*x) - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through blockSplattingExprAssocTypeArguments without block" should { - val cpg = code(""" - |def foo(*arg) - |puts arg - |end - | - |x = 1 - |foo( x+1, key1: x*2, key2: x*3 ) - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through blockSplattingTypeArguments without block" should { - val cpg = code(""" - |def foo (blockArg,&block) - |block.call(blockArg) - |end - | - |x = 10 - |foo(*x do |arg| - | y = x + arg - | puts y - |end - |) - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - // TODO: - "Data flow through blockExprAssocTypeArguments with block argument in the wrapper function" ignore { - val cpg = code(""" - |def foo (blockArg,&block) - |block.call(blockArg) - |end - | - |def foo_wrap (blockArg,&block) - |foo(blockArg,&block) - |end - | - | - |x = 10 - |foo_wrap x do |arg| - | y = 100 + arg - | puts y - |end - | - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through grouping expression with negation" should { - val cpg = code(""" - |def foo(arg) - |return arg - |end - | - |x = false - |y = !(foo x) - |puts y - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through break with args" should { - val cpg = code(""" - |x = 1 - |arr = [x, 2, 3] - |y = arr.each do |num| - | break num if num < 2 - | puts num - |end - |puts y - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 4 - } - } - - // TODO: - "Data flow through next with args" ignore { - val cpg = code(""" - |x = 10 - |a = [1, 2, 3] - |y = a.map do |num| - | next x if num.even? - | num - |end - | - |puts y - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - // TODO: - "Data flow through a global variable" ignore { - val cpg = code(""" - |def foo(arg) - | loop do - | arg += 1 - | if arg > 3 - | $y = arg - | return - | end - | end - |end - | - |x = 1 - |foo x - |puts $y - | - | - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow using a keyword" should { - val cpg = code(""" - |class MyClass - |end - | - |x = MyClass.new - |y = x.class - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through variable params" should { - val cpg = code(""" - |def foo(*args) - | return args - |end - | - |x = 1 - |y = foo(x, "another param") - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through optional params" should { - val cpg = code(""" - |def foo(arg=10) - | return arg + 10 - |end - | - |x = 1 - |y = foo(x) - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow across files" should { - val cpg = code( - """ - |def my_func(x) - | puts x - |end - |""".stripMargin, - "foo.rb" - ) - .moreCode( - """ - |require_relative 'foo.rb' - |x = 1 - |my_func(x) - |""".stripMargin, - "bar.rb" - ) - - "be found in" in { - val source = cpg.literal.code("1").l - val sink = cpg.call.name("puts").argument(1).l - sink.reachableByFlows(source).size shouldBe 1 - } - } - - "Across the file data flow test" should { - val cpg = code( - """ - |def foo(arg) - | puts arg - | loop do - | arg += 1 - | if arg > 3 - | puts arg - | return - | end - | end - |end - |""".stripMargin, - "foo.rb" - ) - .moreCode( - """ - |require_relative 'foo.rb' - |x = 1 - |foo x - |""".stripMargin, - "bar.rb" - ) - - "be found in" in { - val source = cpg.literal.code("1").l - val sink = cpg.call.name("puts").argument(1).lineNumber(3).l - sink.reachableByFlows(source).size shouldBe 1 - val src = cpg.identifier("x").lineNumber(3).l - sink.reachableByFlows(src).size shouldBe 1 - } - - // TODO: Need to be fixed. - "be found for sink in nested block" ignore { - val src = cpg.identifier("x").lineNumber(3).l - val sink = cpg.call.name("puts").argument(1).lineNumber(7).l - sink.reachableByFlows(src).size shouldBe 1 - } - } - - "Data flows for pseudo variable identifiers" should { - "Data flow for __LINE__ variable identifier" should { - val cpg = code(""" - |x=1 - |a=x+__LINE__ - |puts a - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - } - - "Data flow for chained command with do-block with parentheses" should { - val cpg = code(""" - |def foo() - | yield if block_given? - |end - | - |y = foo do - | x = 1 - | [x+1,x+2] - |end.sum(10) - | - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for chained command with do-block without parentheses" should { - val cpg = code(""" - |def foo() - | yield if block_given? - |end - | - |y = foo do - | x = 1 - | [x+1,x+2] - |end.sum 10 - | - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow for yield block specified alongwith the call" should { - val cpg = code(""" - |x=10 - |def foo(x) - | a = yield - | puts a - |end - | - |foo(x) { - | x + 2 - |} - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 1 - /* - * TODO the flow count shows 1 since the origin is considered as x + 2 - * The actual origin is x=10. However, this is not considered since there is - * no REACHING_DEF edge from the x of 'x=10' to the x of 'x + 2'. - * There are already other disabled data flow test cases for this problem. Once solved, it should - * be possible to set the required count to 2 - */ - - } - } - - "Data flows through range operators" should { - val cpg = code(""" - |x = 10 - |y=0 - |for i in 1...10 do - | x += i - | if (x > 10) - | y = x - | end - |end - | - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 3 - } - } - - "Data flow through unless modifier" should { - val cpg = code(""" - |x = 1 - | - |x += 2 unless x.zero? - | puts(x) - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 3 - } - } - - "Data flow through invocation or command with EMARK" should { - val cpg = code(""" - |x=12 - |def woo(x) - | return x == 10 - |end - | - |if !woo x - | puts x - |else - | puts "No" - |end - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 3 - } - } - // TODO: - "Data flow through overloaded operator method" ignore { - val cpg = code(""" - |class Foo - | @@x = 1 - | def +(y) - | @@x + y - | end - |end - | - |y = Foo.new + 1 - | - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.member.name("@@x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 3 - } - } - - // TODO: - "Data flow through assignment-like method identifier" ignore { - val cpg = code(""" - |class Foo - | @@x = 1 - | def CONST=(y) - | return @@x == y - | end - |end - |puts Foo::CONST= 2 - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.member.name("@@x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 3 - } - } - - // TODO: - "Data flow through a when argument context" ignore { - val cpg = code(""" - |x = 10 - | - |case x - | - |when 1..5 - | y = x - |when 5..10 - | z = x - |when 10..15 - | w = x - |else - | _p = x - |end - | - |puts _p - |puts w - |puts y - |puts z - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 3 - } - } - - "Data flow through ensureClause" should { - val cpg = code(""" - |begin - | x = File.open("myFile.txt", "r") - | x << "#{content} \n" - |rescue - | x = "pqr" - |ensure - | x = "abc" - | y = x - |end - | - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 // flow through the rescue is not a flow - } - } - - "Data flow through begin-else" should { - val cpg = code(""" - |begin - | x = File.open("myFile.txt", "r") - | x << "#{content} \n" - |rescue - | x = "pqr" - |else - | y = x - |ensure - | x = "abc" - |end - | - |puts y - |""".stripMargin).moreCode( - """ - |My file - |""".stripMargin, - "myFile.txt" - ) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 3 - } - } - - "Data flow through block argument context" should { - val cpg = code(""" - |x=10 - |y=0 - |def foo(n, &block) - | woo(n, &block) - |end - | - |def woo(n, &block) - | n.times {yield} - |end - | - |foo(5) { - | y = x - |} - | - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Data flow through block splatting type arguments context" should { - val cpg = code(""" - |x=10 - |y=0 - |def foo(*n, &block) - | woo(*n, &block) - |end - | - |def woo(n, &block) - | n.times {yield} - |end - | - |foo(5) { - | y = x - |} - | - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Flow through tainted object" should { - val cpg = code(""" - |def put_req(api_endpoint, params) - | puts "Hitting " + api_endpoint + " with params: " + params - |end - |class TestClient - | def get_event_data(accountId) - | payload = accountId - | r = put_req( - | "https://localhost:8080/v3/users/me/", - | params=payload - | ) - | end - |end - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("accountId").l - val sink = cpg.call.name("put_req").l - sink.reachableByFlows(source).size shouldBe 1 - } - } - - // TODO: - "Flow for a global variable" ignore { - val cpg = code(""" - |$person_height = 6 - |class Person - | def height_in_cm - | puts $person_height * 30 - | end - |end - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("$person_height").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "Flow for nested puts calls" should { - val cpg = code(""" - |x=10 - |def put_name(x) - | puts x - |end - |def nested_put(x) - | put_name(x) - |end - |def double_nested_put(x) - | nested_put(x) - |end - |double_nested_put(x) - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 5 - } - } - - "Data flow through a keyword? named method usage" should { - val cpg = code(""" - |x = 1 - |y = x.nil? - |puts y - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).size shouldBe 2 - } - } - - "Data flow through a keyword inside a association" should { - val cpg = code(""" - |def foo(arg) - |puts arg - |end - | - |x = 1 - |foo if: x.nil? - |""".stripMargin) - - "be found" in { - val src = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(src).size shouldBe 2 - } - } - - "Data flow through a regex interpolation" should { - val cpg = code(s""" - |x="abc" - |y=/x#{x}b/ - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "flow through a regex interpolation with multiple expressions" should { - val cpg = code(""" - |x="abc" - |y=/x#{x}b#{x+'z'}b{x+'y'+'z'}w/ - |puts y - |""".stripMargin) - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 3 - } - } - - "flow through a proc definition using Proc.new and flow originating within the proc" should { - val cpg = code(""" - |y = Proc.new { - |x=1 - |x - |} - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "flow through a proc definition with non-empty block and zero parameters" ignore { - val cpg = code(""" - |x=10 - |y = x - |-> { - |puts y - |}.call - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "flow through a proc definition with non-empty block and non-zero parameters" should { - val cpg = code(""" - |x=10 - |-> (arg){ - |puts arg - |}.call(x) - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "flow through a method call with safe navigation operator with parantheses" should { - val cpg = code(""" - |class Foo - | def bar(arg) - | return arg - | end - |end - |x=1 - |foo = Foo.new - |y = foo&.bar(x) - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "flow through a method call with safe navigation operator without parantheses" should { - val cpg = code(""" - |class Foo - | def bar(arg) - | return arg - | end - |end - |x=1 - |foo = Foo.new - |y = foo&.bar x - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "flow through a method call present in next line, with the second line starting with `.`" should { - val cpg = code(""" - |class Foo - | def bar(x) - | return x - | end - |end - | - |x = 1 - |foo = Foo.new - |y = foo - | .bar(1) - |puts y - |""".stripMargin) - - "find flow to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 1 - } - } - - "flow through a method call present in next line, with the first line ending with `.`" should { - val cpg = code(""" - |class Foo - | def bar(x) - | return x - | end - |end - | - |x = 1 - |foo = Foo.new - |y = foo. - | bar(1) - |puts y - |""".stripMargin) - - "find flow to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 1 - } - } - - "flow through statement when regular expression literal passed after `when`" should { - val cpg = code(""" - |x = 2 - |a = 2 - | - |case a - | when /^ch/ - | b = x - | puts b - |end - |""".stripMargin) - - "find flows to the sink" in { - - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "flow through interpolated double-quoted string literal " should { - val cpg = code(""" - |x = "foo" - |y = :"bar #{x}" - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "flow through conditional return statement" should { - val cpg = code(""" - |class Foo - | def bar(value) - | j = 0 - | return(value) unless j == 0 - | end - |end - | - |x = 10 - |foo = Foo.new - |y = foo.bar(x) - |puts y - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "flow through statement with ternary operator with multiple line" in { - val cpg = code(""" - |x = 2 - |y = 3 - |z = 4 - | - |w = x == 2 ? - | y - | : z - |puts y - |""".stripMargin) - - val source = cpg.identifier.name("y").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - - "flow through endless method" in { - val cpg = code(""" - |def multiply(a,b) = a*b - |x = 10 - |y = multiply(3,x) - |puts y - |""".stripMargin) - - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - - "flow through symbol literal defined using \\:" should { - val cpg = code(""" - |def foo(arg) - |hash = {:y => arg} - |puts hash - |end - | - |x = 3 - |foo(x) - |""".stripMargin) - - "find flows to the sink" in { - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - } - - "flow through %w array" in { - val cpg = code(""" - |a = %w[b c] - |puts a - |""".stripMargin) - - val source = cpg.literal.code("b").l - val sink = cpg.call.name("puts").l - val List(flow) = sink.reachableByFlows(source).map(flowToResultPairs).distinct.sortBy(_.length).l - flow shouldBe List(("[b, c]", 2), ("[b, c]", -1), ("a = %w[b c]", 2), ("puts a", 3)) - } - - "flow through hash containing splatting literal" in { - val cpg = code(""" - |x={:y=>1} - |z = { - |**x - |} - |puts z - |""".stripMargin) - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - - "dataflow in method defined under class << self block" ignore { - // Marked it ignored as flow from identifier "firstName" to call "puts" is missing - val cpg = code(""" - class MyClass - | - | class << self - | def printPII - | firstName="somename" - | puts "log PII #{firstName}" - | end - | end - |end - | - |MyClass.printPII""".stripMargin) - - val source = cpg.identifier.name("firstName").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - - "flow through association identifier" in { - val cpg = code(""" - |def foo(a:) - | puts a - |end - | - |x =1 - |foo(a:x) - |""".stripMargin) - - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - - "flow through special prefix methods" in { - /* We only check private_class_method here. The mechanism is similar to others: - * attr_reader - * attr_writer - * attr_accessor - * remove_method - * public_class_method - * private - * protected - */ - val cpg = code(""" - |class Foo - | z = 1 - | private_class_method def self.bar(x) - | x - | end - | - | y = self.bar(z) - | puts y - |end - |""".stripMargin) - - val source = cpg.identifier.name("z").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - - "flow through %i array" in { - val cpg = code(""" - |a = %i[b - | c] - |puts a - |""".stripMargin) - - val source = cpg.literal.code("b").l - val sink = cpg.call.name("puts").l - val List(flow) = sink.reachableByFlows(source).map(flowToResultPairs).distinct.sortBy(_.length).l - flow shouldBe List( - ("[b, c]", 2), - ("[b, c]", -1), - ( - """|a = %i[b - | c]""".stripMargin, - 2 - ), - ("puts a", 4) - ) - } - - "flow through array constructor using []" in { - val cpg = code(""" - |x=1 - |y=x - |z = Array[y,2] - |puts "#{z}" - |""".stripMargin) - - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } - - "flow through array constructor using [] and command in []" in { - val cpg = code(""" - |def foo(arg) - |return arg - |end - | - |x=1 - |y=x - |z = Array[foo y] - |puts "#{z}" - |""".stripMargin) - - val source = cpg.identifier.name("x").l - val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ArrayTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ArrayTests.scala deleted file mode 100644 index 7899dddd5287..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ArrayTests.scala +++ /dev/null @@ -1,319 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class ArrayTests extends RubyParserAbstractTest { - - "An empty array literal" should { - - "be parsed as a primary expression" when { - - "it uses the traditional [, ] delimiters" in { - val code = "[]" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | BracketedArrayConstructor - | [ - | ]""".stripMargin - } - - "it uses the %w[ ] delimiters" in { - val code = "%w[]" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedWordArrayConstructor - | %w[ - | ]""".stripMargin - } - - "it uses the %W< > delimiters" in { - val code = "%W<>" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | ExpandedWordArrayConstructor - | %W< - | >""".stripMargin - } - - "it uses the %i[ ] delimiters" in { - val code = "%i[]" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedSymbolArrayConstructor - | %i[ - | ]""".stripMargin - } - - "it uses the %I{ } delimiters" in { - val code = "%I{}" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | ExpandedSymbolArrayConstructor - | %I{ - | }""".stripMargin - } - } - } - - "A non-empty word array literal" should { - - "be parsed as a primary expression" when { - - "it uses the %w[ ] delimiters" in { - val code = "%w[x y z]" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedWordArrayConstructor - | %w[ - | NonExpandedArrayElements - | NonExpandedArrayElement - | x - | NonExpandedArrayElement - | y - | NonExpandedArrayElement - | z - | ]""".stripMargin - } - - "it uses the %w( ) delimiters" in { - val code = "%w(x y z)" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedWordArrayConstructor - | %w( - | NonExpandedArrayElements - | NonExpandedArrayElement - | x - | NonExpandedArrayElement - | y - | NonExpandedArrayElement - | z - | )""".stripMargin - } - - "it uses the %w{ } delimiters" in { - val code = "%w{x y z}" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedWordArrayConstructor - | %w{ - | NonExpandedArrayElements - | NonExpandedArrayElement - | x - | NonExpandedArrayElement - | y - | NonExpandedArrayElement - | z - | }""".stripMargin - } - - "it uses the %w< > delimiters" in { - val code = "%w" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedWordArrayConstructor - | %w< - | NonExpandedArrayElements - | NonExpandedArrayElement - | x - | \ - | y - | >""".stripMargin - } - - "it uses the %w- - delimiters" in { - val code = "%w-x y z-" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedWordArrayConstructor - | %w- - | NonExpandedArrayElements - | NonExpandedArrayElement - | x - | NonExpandedArrayElement - | y - | NonExpandedArrayElement - | z - | -""".stripMargin - } - - "it spans multiple lines" in { - val code = - """%w( - | bob - | cod - | dod - |)""".stripMargin - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedWordArrayConstructor - | %w( - | NonExpandedArrayElements - | NonExpandedArrayElement - | b - | o - | b - | NonExpandedArrayElement - | c - | o - | d - | NonExpandedArrayElement - | d - | o - | d - | )""".stripMargin - - } - - "it uses the %W( ) delimiters and contains a numeric interpolation" in { - val code = "%W(x#{1})" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | ExpandedWordArrayConstructor - | %W( - | ExpandedArrayElements - | ExpandedArrayElement - | x - | DelimitedArrayItemInterpolation - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | } - | )""".stripMargin - } - - "it spans multiple lines and contains a numeric interpolation" in { - val code = - """%W[ - | x#{0} - |]""".stripMargin - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | ExpandedWordArrayConstructor - | %W[ - | ExpandedArrayElements - | ExpandedArrayElement - | x - | DelimitedArrayItemInterpolation - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 0 - | } - | ]""".stripMargin - } - } - } - - "A non-empty symbol array literal" should { - - "be parsed as a primary expression" when { - - "it uses the %i< > delimiters" in { - val code = "%i" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedSymbolArrayConstructor - | %i< - | NonExpandedArrayElements - | NonExpandedArrayElement - | x - | NonExpandedArrayElement - | y - | >""".stripMargin - } - - "it uses the %i{ } delimiters" in { - val code = "%i{x\\ y}" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedSymbolArrayConstructor - | %i{ - | NonExpandedArrayElements - | NonExpandedArrayElement - | x - | \ - | y - | }""".stripMargin - } - - "it uses the %i[ ] delimiters nestedly" in { - val code = "%i[x [y]]" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedSymbolArrayConstructor - | %i[ - | NonExpandedArrayElements - | NonExpandedArrayElement - | x - | NonExpandedArrayElement - | [ - | y - | ] - | ]""".stripMargin - - } - - "it uses the %i( ) delimiters in a multi-line fashion" in { - val code = - """%i( - |x y - |z - |)""".stripMargin - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | NonExpandedSymbolArrayConstructor - | %i( - | NonExpandedArrayElements - | NonExpandedArrayElement - | x - | NonExpandedArrayElement - | y - | NonExpandedArrayElement - | z - | )""".stripMargin - } - - "it uses the %I( ) delimiters and contains a numeric interpolation" in { - val code = "%I(x#{0} x1)" - printAst(_.primary(), code) shouldEqual - """ArrayConstructorPrimary - | ExpandedSymbolArrayConstructor - | %I( - | ExpandedArrayElements - | ExpandedArrayElement - | x - | DelimitedArrayItemInterpolation - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 0 - | } - | ExpandedArrayElement - | x - | 1 - | )""".stripMargin - } - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/AssignmentTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/AssignmentTests.scala deleted file mode 100644 index 4ba931015782..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/AssignmentTests.scala +++ /dev/null @@ -1,81 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class AssignmentTests extends RubyParserAbstractTest { - - "single assignment" should { - - "be parsed as a statement" when { - - "it contains no whitespace before `=`" in { - val code = "x=1" - printAst(_.statement(), code) shouldEqual - """ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | SingleAssignmentExpression - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | x - | = - | MultipleRightHandSide - | ExpressionOrCommands - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1""".stripMargin - } - } - } - - "multiple assignment" should { - - "be parsed as a statement" when { - "two identifiers are assigned an array containing two calls" in { - val code = "p, q = [foo(), bar()]" - printAst(_.statement(), code) shouldEqual - """ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | MultipleAssignmentExpression - | MultipleLeftHandSideAndpackingLeftHandSideMultipleLeftHandSide - | MultipleLeftHandSideItem - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | p - | , - | MultipleLeftHandSideItem - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | q - | = - | MultipleRightHandSide - | ExpressionOrCommands - | ExpressionExpressionOrCommand - | PrimaryExpression - | ArrayConstructorPrimary - | BracketedArrayConstructor - | [ - | ExpressionsOnlyIndexingArguments - | Expressions - | PrimaryExpression - | InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | BlankArgsArgumentsWithParentheses - | ( - | ) - | , - | PrimaryExpression - | InvocationWithParenthesesPrimary - | MethodIdentifier - | bar - | BlankArgsArgumentsWithParentheses - | ( - | ) - | ]""".stripMargin - } - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/BeginExpressionTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/BeginExpressionTests.scala deleted file mode 100644 index 17ec8f53514b..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/BeginExpressionTests.scala +++ /dev/null @@ -1,58 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class BeginExpressionTests extends RubyParserAbstractTest { - - "A begin expression" should { - - "be parsed as a primary expression" when { - - "it contains a `rescue` clause with both exception class and exception variable" in { - val code = """begin - |1/0 - |rescue ZeroDivisionError => e - |end""".stripMargin - - printAst(_.primary(), code) shouldBe - """BeginExpressionPrimary - | BeginExpression - | begin - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | MultiplicativeExpression - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | / - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 0 - | RescueClause - | rescue - | ExceptionClass - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | ZeroDivisionError - | ExceptionVariableAssignment - | => - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | e - | ThenClause - | CompoundStatement - | end""".stripMargin - } - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/BeginStatementTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/BeginStatementTests.scala deleted file mode 100644 index 38591dc53fd6..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/BeginStatementTests.scala +++ /dev/null @@ -1,40 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class BeginStatementTests extends RubyParserAbstractTest { - - "BEGIN statement" should { - - "be parsed as a statement" when { - - "defined in a single line" in { - val code = "BEGIN { 1 }" - printAst(_.statement(), code) shouldEqual - """BeginStatement - | BEGIN - | { - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | }""".stripMargin - } - - "empty (single-line)" in { - val code = "BEGIN {}" - printAst(_.statement(), code) shouldEqual - """BeginStatement - | BEGIN - | { - | CompoundStatement - | }""".stripMargin - } - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/CaseConditionTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/CaseConditionTests.scala deleted file mode 100644 index 994946a46ad7..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/CaseConditionTests.scala +++ /dev/null @@ -1,194 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class CaseConditionTests extends RubyParserAbstractTest { - - "A case expression" should { - - "be parsed as a primary expression" when { - - "it contains just one `when` branch" in { - val code = - """case something - | when 1 - | puts 2 - |end - |""".stripMargin - - printAst(_.primary(), code) shouldBe - """CaseExpressionPrimary - | CaseExpression - | case - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | something - | WhenClause - | when - | WhenArgument - | Expressions - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | ThenClause - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | SingleCommandOnlyInvocationWithoutParentheses - | SimpleMethodCommand - | MethodIdentifier - | puts - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 2 - | end""".stripMargin - } - - "it contains both an empty `when` and `else` branch" in { - val code = - """case something - | when 1 - | else - | end - |""".stripMargin - - printAst(_.primary(), code) shouldBe - """CaseExpressionPrimary - | CaseExpression - | case - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | something - | WhenClause - | when - | WhenArgument - | Expressions - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | ThenClause - | CompoundStatement - | ElseClause - | else - | CompoundStatement - | end""".stripMargin - } - - "it uses `then` as separator for `when`" in { - val code = - """case something - | when 1 then - | end - |""".stripMargin - - printAst(_.primary(), code) shouldBe - """CaseExpressionPrimary - | CaseExpression - | case - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | something - | WhenClause - | when - | WhenArgument - | Expressions - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | ThenClause - | then - | CompoundStatement - | end""".stripMargin - } - - "it contains two single-line `when-then` branches" in { - val code = - """case x - | when 1 then 2 - | when 2 then 3 - | end - |""".stripMargin - - printAst(_.primary(), code) shouldBe - """CaseExpressionPrimary - | CaseExpression - | case - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x - | WhenClause - | when - | WhenArgument - | Expressions - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | ThenClause - | then - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 2 - | WhenClause - | when - | WhenArgument - | Expressions - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 2 - | ThenClause - | then - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 3 - | end""".stripMargin - } - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ClassDefinitionTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ClassDefinitionTests.scala deleted file mode 100644 index 37ddb665eed7..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ClassDefinitionTests.scala +++ /dev/null @@ -1,112 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class ClassDefinitionTests extends RubyParserAbstractTest { - - "A one-line singleton class definition" should { - - "be parsed as a primary expression" when { - - "it contains no members" in { - val code = "class << self ; end" - printAst(_.primary(), code) shouldBe - """ClassDefinitionPrimary - | ClassDefinition - | class - | << - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | PseudoVariableIdentifierVariableReference - | SelfPseudoVariableIdentifier - | self - | ; - | BodyStatement - | CompoundStatement - | end""".stripMargin - } - - "it contains a single numeric literal in its body" in { - val code = "class X 1 end" - printAst(_.primary(), code) shouldBe - """ClassDefinitionPrimary - | ClassDefinition - | class - | ClassOrModuleReference - | X - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | end""".stripMargin - } - } - } - - "A multi-line singleton class definition" should { - - "be parsed as a primary expression" when { - - "it contains a single method definition" in { - val code = - """class << x - | def show; puts self; end - |end""".stripMargin - printAst(_.primary(), code) shouldBe - """ClassDefinitionPrimary - | ClassDefinition - | class - | << - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | show - | MethodParameterPart - | BodyStatement - | CompoundStatement - | ; - | Statements - | ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | SingleCommandOnlyInvocationWithoutParentheses - | SimpleMethodCommand - | MethodIdentifier - | puts - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | VariableReferencePrimary - | PseudoVariableIdentifierVariableReference - | SelfPseudoVariableIdentifier - | self - | ; - | end - | end""".stripMargin - } - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/EnsureClauseTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/EnsureClauseTests.scala deleted file mode 100644 index 4610a18f801c..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/EnsureClauseTests.scala +++ /dev/null @@ -1,58 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class EnsureClauseTests extends RubyParserAbstractTest { - - "An ensure statement" should { - - "be parsed as a standalone statement" when { - - "in the immediate scope of a `def` block" in { - val code = - """def refund - | ensure - | redirect_to paddle_charge_path(@charge) - |end""".stripMargin - printAst(_.methodDefinition(), code) shouldEqual - """MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | refund - | MethodParameterPart - | BodyStatement - | CompoundStatement - | EnsureClause - | ensure - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | SingleCommandOnlyInvocationWithoutParentheses - | SimpleMethodCommand - | MethodIdentifier - | redirect_to - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | InvocationWithParenthesesPrimary - | MethodIdentifier - | paddle_charge_path - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | @charge - | ) - | end""".stripMargin - } - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/HashLiteralTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/HashLiteralTests.scala deleted file mode 100644 index f6e0c6d68fee..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/HashLiteralTests.scala +++ /dev/null @@ -1,129 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class HashLiteralTests extends RubyParserAbstractTest { - - "A standalone hash literal" should { - - "be parsed as a primary expression" when { - - "it contains no elements" in { - val code = "{ }" - printAst(_.primary(), code) shouldEqual - """HashConstructorPrimary - | HashConstructor - | { - | }""".stripMargin - } - - "it contains a single splatting identifier" in { - val code = "{ **x }" - printAst(_.primary(), code) shouldEqual - """HashConstructorPrimary - | HashConstructor - | { - | HashConstructorElements - | HashConstructorElement - | ** - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x - | }""".stripMargin - } - - "it contains two consecutive splatting identifiers" in { - val code = "{**x, **y}" - printAst(_.primary(), code) shouldEqual - """HashConstructorPrimary - | HashConstructor - | { - | HashConstructorElements - | HashConstructorElement - | ** - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x - | , - | HashConstructorElement - | ** - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | y - | }""".stripMargin - - } - - "it contains an association between two splatting identifiers" in { - val code = "{**x, y => 1, **z}" - printAst(_.primary(), code) shouldEqual - """HashConstructorPrimary - | HashConstructor - | { - | HashConstructorElements - | HashConstructorElement - | ** - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x - | , - | HashConstructorElement - | Association - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | y - | => - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | , - | HashConstructorElement - | ** - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | z - | }""".stripMargin - } - - "it contains a single splatting method invocation" in { - val code = "{**group_by_type(some)}" - printAst(_.primary(), code) shouldEqual - """HashConstructorPrimary - | HashConstructor - | { - | HashConstructorElements - | HashConstructorElement - | ** - | PrimaryExpression - | InvocationWithParenthesesPrimary - | MethodIdentifier - | group_by_type - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | some - | ) - | }""".stripMargin - } - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/InvocationWithParenthesesTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/InvocationWithParenthesesTests.scala deleted file mode 100644 index baaded7a1f47..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/InvocationWithParenthesesTests.scala +++ /dev/null @@ -1,311 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class InvocationWithParenthesesTests extends RubyParserAbstractTest { - - "A method invocation with parentheses" should { - - "be parsed as a primary expression" when { - - "it contains no arguments" in { - val code = "foo()" - - printAst(_.primary(), code) shouldEqual - """InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | BlankArgsArgumentsWithParentheses - | ( - | )""".stripMargin - } - - "it contains no arguments but has newline in between" in { - val code = - """foo( - |) - |""".stripMargin - - printAst(_.primary(), code) shouldEqual - s"""InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | BlankArgsArgumentsWithParentheses - | ( - | )""".stripMargin - } - - "it contains a single numeric literal positional argument" in { - val code = "foo(1)" - - printAst(_.primary(), code) shouldEqual - """InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | )""".stripMargin - } - - "it contains a single numeric literal keyword argument" in { - val code = "foo(region: 1)" - - printAst(_.primary(), code) shouldEqual - s"""InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | AssociationArgument - | Association - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | region - | : - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | )""".stripMargin - } - - "it contains an identifier keyword argument" in { - val code = "foo(region:region)" - - printAst(_.primary(), code) shouldEqual - s"""InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | AssociationArgument - | Association - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | region - | : - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | region - | )""".stripMargin - } - - "it contains a non-empty regex literal keyword argument" in { - val code = "foo(id: /.*/)" - printAst(_.primary(), code) shouldEqual - """InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | AssociationArgument - | Association - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | id - | : - | PrimaryExpression - | LiteralPrimary - | RegularExpressionLiteral - | / - | .* - | / - | )""".stripMargin - } - - "it contains a single symbol literal positional argument" in { - val code = "foo(:region)" - - printAst(_.primary(), code) shouldEqual - """InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | SymbolLiteral - | Symbol - | :region - | )""".stripMargin - } - - "it contains a single symbol literal positional argument and trailing comma" in { - val code = "foo(:region,)" - - printAst(_.primary(), code) shouldEqual - """InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | SymbolLiteral - | Symbol - | :region - | , - | )""".stripMargin - } - - "it contains a splatting expression before a keyword argument" in { - val code = "foo(*x, y: 1)" - printAst(_.primary(), code) shouldEqual - """InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | SplattingArgumentArgument - | SplattingArgument - | * - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x - | , - | AssociationArgument - | Association - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | y - | : - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | )""".stripMargin - } - - "it contains a keyword-named keyword argument" in { - val code = "foo(if: true)" - printAst(_.primary(), code) shouldEqual - """InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | AssociationArgument - | Association - | Keyword - | if - | : - | PrimaryExpression - | VariableReferencePrimary - | PseudoVariableIdentifierVariableReference - | TruePseudoVariableIdentifier - | true - | )""".stripMargin - } - - "it contains a safe navigation operator with no parameters" in { - val code = "foo&.bar()" - printAst(_.primary(), code) shouldEqual - """ChainedInvocationPrimary - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | foo - | &. - | MethodName - | MethodIdentifier - | bar - | BlankArgsArgumentsWithParentheses - | ( - | )""".stripMargin - } - - "it contains a safe navigation operator with non-zero parameters" in { - val code = "foo&.bar(1, 2)" - printAst(_.primary(), code) shouldEqual - """ChainedInvocationPrimary - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | foo - | &. - | MethodName - | MethodIdentifier - | bar - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | , - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 2 - | )""".stripMargin - } - - "it spans two lines, with the second line starting with `.`" in { - val code = "foo\n .bar" - printAst(_.primary(), code) shouldEqual - """ChainedInvocationPrimary - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | foo - | . - | MethodName - | MethodIdentifier - | bar""".stripMargin - } - - "it spans two lines, with the first line ending with `.`" in { - val code = "foo.\n bar" - printAst(_.primary(), code) shouldEqual - """ChainedInvocationPrimary - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | foo - | . - | MethodName - | MethodIdentifier - | bar""".stripMargin - } - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/InvocationWithoutParenthesesTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/InvocationWithoutParenthesesTests.scala deleted file mode 100644 index b5d3c7b03776..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/InvocationWithoutParenthesesTests.scala +++ /dev/null @@ -1,176 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class InvocationWithoutParenthesesTests extends RubyParserAbstractTest { - - "A method invocation without parentheses" should { - - "be parsed as a primary expression" when { - - "it contains a keyword?-named member" in { - val code = "task.nil?" - - printAst(_.primary(), code) shouldBe - """ChainedInvocationPrimary - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | task - | . - | MethodName - | MethodIdentifier - | MethodOnlyIdentifier - | Keyword - | nil - | ?""".stripMargin - } - - "it is keyword?-named" in { - val code = "do?" - - printAst(_.primary(), code) shouldBe - """MethodOnlyIdentifierPrimary - | MethodOnlyIdentifier - | Keyword - | do - | ?""".stripMargin - } - - "it is keyword!-named" in { - val code = "return!" - - printAst(_.primary(), code) shouldBe - """MethodOnlyIdentifierPrimary - | MethodOnlyIdentifier - | Keyword - | return - | !""".stripMargin - } - } - } - - "A command with do block" should { - - "be parsed as a statement" when { - - "it contains only one argument" in { - val code = """it 'should print 1' do - | puts 1 - |end - |""".stripMargin - - printAst(_.statement(), code) shouldBe - """ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | ChainedCommandDoBlockInvocationWithoutParentheses - | ChainedCommandWithDoBlock - | ArgsAndDoBlockAndMethodIdCommandWithDoBlock - | MethodIdentifier - | it - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | StringExpressionPrimary - | SimpleStringExpression - | SingleQuotedStringLiteral - | 'should print 1' - | DoBlock - | do - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | SingleCommandOnlyInvocationWithoutParentheses - | SimpleMethodCommand - | MethodIdentifier - | puts - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | end""".stripMargin - - } - - "it contains a safe navigation operator with no parameters" in { - val code = "foo&.bar" - printAst(_.primary(), code) shouldEqual - """ChainedInvocationPrimary - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | foo - | &. - | MethodName - | MethodIdentifier - | bar""".stripMargin - } - - "it contains a safe navigation operator with non-zero parameters" in { - val code = "foo&.bar 1,2" - printAst(_.command(), code) shouldEqual - """MemberAccessCommand - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | foo - | &. - | MethodName - | MethodIdentifier - | bar - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | , - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 2""".stripMargin - } - - } - } - - "invocation with association arguments" should { - "have correct structure for association arguments" in { - val code = """foo bar:""" - printAst(_.program(), code) shouldBe - """Program - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | SingleCommandOnlyInvocationWithoutParentheses - | SimpleMethodCommand - | MethodIdentifier - | foo - | ArgumentsWithoutParentheses - | Arguments - | AssociationArgument - | Association - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | bar - | : - | EOF""".stripMargin - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/MethodDefinitionTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/MethodDefinitionTests.scala deleted file mode 100644 index d055e207d37f..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/MethodDefinitionTests.scala +++ /dev/null @@ -1,933 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class MethodDefinitionTests extends RubyParserAbstractTest { - - "A one-line empty method definition" should { - - "be parsed as a primary expression" when { - - "it contains no parameters" in { - val code = "def foo; end" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains a mandatory parameter" in { - val code = "def foo(x);end" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | MandatoryParameter - | x - | ) - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains an optional numeric parameter" in { - val code = "def foo(x=1);end" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | OptionalParameter - | x - | = - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | ) - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains two parameters, the last of which a &-parameter" in { - val code = "def foo(x, &y); end" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | MandatoryParameter - | x - | , - | Parameter - | ProcParameter - | & - | y - | ) - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains a named (array) splatting argument" in { - val code = "def foo(*arr); end" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | ArrayParameter - | * - | arr - | ) - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains a named (hash) splatting argument" in { - val code = "def foo(**hash); end" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | HashParameter - | ** - | hash - | ) - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains both a named array and hash splatting argument" in { - val code = "def foo(*arr, **hash); end" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | ArrayParameter - | * - | arr - | , - | Parameter - | HashParameter - | ** - | hash - | ) - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains an optional parameter before a mandatory one" in { - val code = "def foo(x=1,y); end" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | OptionalParameter - | x - | = - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | , - | Parameter - | MandatoryParameter - | y - | ) - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains a keyword parameter" in { - val code = "def foo(x: 1); end" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | KeywordParameter - | x - | : - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | ) - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains a mandatory keyword parameter" in { - val code = "def foo(x:) ; end" - printAst(_.primary(), code) shouldBe - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | KeywordParameter - | x - | : - | ) - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains two mandatory keyword parameters" in { - val code = "def foo(name:, surname:) ; end" - printAst(_.primary(), code) shouldBe - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | KeywordParameter - | name - | : - | , - | Parameter - | KeywordParameter - | surname - | : - | ) - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - } - } - - "A multi-line method definition" should { - - "be parsed as a primary expression" when { - - "it contains a `rescue` clause" in { - val code = """def foo - | 1/0 - | rescue ZeroDivisionError => e - |end""".stripMargin - printAst(_.primary(), code) shouldBe - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | MultiplicativeExpression - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | / - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 0 - | RescueClause - | rescue - | ExceptionClass - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | ZeroDivisionError - | ExceptionVariableAssignment - | => - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | e - | ThenClause - | CompoundStatement - | end""".stripMargin - - } - } - - } - - "An endless method definition" should { - - "be parsed as a primary expression" when { - - "it contains no arguments" in { - val code = "def foo = x" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | MethodIdentifier - | foo - | MethodParameterPart - | = - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x""".stripMargin - } - - "it contains a line break right after `=`" in { - val code = "def foo =\n x" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | MethodIdentifier - | foo - | MethodParameterPart - | = - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x""".stripMargin - } - - "it contains no arguments and a string literal on the RHS" in { - val code = """def foo = "something"""" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | MethodIdentifier - | foo - | MethodParameterPart - | = - | PrimaryExpression - | StringExpressionPrimary - | SimpleStringExpression - | DoubleQuotedStringLiteral - | " - | something - | """".stripMargin - } - - "it contains a single mandatory argument" in { - val code = "def id(x) = x" - printAst(_.primary(), code) shouldEqual - """MethodDefinitionPrimary - | MethodDefinition - | def - | MethodIdentifier - | id - | MethodParameterPart - | ( - | Parameters - | Parameter - | MandatoryParameter - | x - | ) - | = - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x""".stripMargin - } - } - - "not be recognized" when { - - // This test exists to make sure that `foo2=` is not parsed as an endless method, as - // endless methods cannot end in `=`. - "its name ends in `=`" in { - val code = - """def foo1 - |end - |def foo2=(arg) - |end - |""".stripMargin - - printAst(_.compoundStatement(), code) shouldEqual - """CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo1 - | MethodParameterPart - | BodyStatement - | CompoundStatement - | end - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | AssignmentLikeMethodIdentifier - | foo2= - | MethodParameterPart - | ( - | Parameters - | Parameter - | MandatoryParameter - | arg - | ) - | BodyStatement - | CompoundStatement - | end""".stripMargin - - } - - // This test makes sure that the `end` after `def foo2=` is not parsed as part of its definition, - // which could happen if `foo2=` was parsed as two separate tokens (LOCAL_VARIABLE_IDENTIFIER, EQ) - // instead of just ASSIGNMENT_LIKE_METHOD_IDENTIFIER. - // Issue report: https://github.com/joernio/joern/issues/3270 - "its name ends in `=` and the next keyword is `end`" in { - val code = - """module SomeModule - |def foo1 - | return unless true - |end - |def foo2=(arg) - |end - |end - |""".stripMargin - printAst(_.compoundStatement(), code) shouldEqual - """CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | ModuleDefinitionPrimary - | ModuleDefinition - | module - | ClassOrModuleReference - | SomeModule - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo1 - | MethodParameterPart - | BodyStatement - | CompoundStatement - | Statements - | ModifierStatement - | ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | ReturnArgsInvocationWithoutParentheses - | return - | unless - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | PseudoVariableIdentifierVariableReference - | TruePseudoVariableIdentifier - | true - | end - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | AssignmentLikeMethodIdentifier - | foo2= - | MethodParameterPart - | ( - | Parameters - | Parameter - | MandatoryParameter - | arg - | ) - | BodyStatement - | CompoundStatement - | end - | end""".stripMargin - } - } - } - - "method definition with proc parameters" should { - "have correct structure for proc parameters with name" in { - val code = - """def foo(&block) - | yield - |end - |""".stripMargin - - printAst(_.primary(), code) shouldBe - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | ProcParameter - | & - | block - | ) - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | YieldWithOptionalArgumentPrimary - | YieldWithOptionalArgument - | yield - | end""".stripMargin - } - - "have correct structure for proc parameters with no name" in { - val code = - """def foo(&) - | yield - |end - |""".stripMargin - - printAst(_.primary(), code) shouldBe - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | ProcParameter - | & - | ) - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | YieldWithOptionalArgumentPrimary - | YieldWithOptionalArgument - | yield - | end""".stripMargin - } - } - - "method definition for mandatory parameters" should { - "have correct structure for mandatory parameters" in { - val code = "def foo(bar:) end" - printAst(_.primary(), code) shouldBe - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | ( - | Parameters - | Parameter - | KeywordParameter - | bar - | : - | ) - | BodyStatement - | CompoundStatement - | end""".stripMargin - } - - "have correct structure for a hash created using a method" in { - val code = - """def data - | { - | first_link:, - | action_link_group:, - | } - |end""".stripMargin - - printAst(_.primary(), code) shouldBe - """MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | data - | MethodParameterPart - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | HashConstructorPrimary - | HashConstructor - | { - | HashConstructorElements - | HashConstructorElement - | Association - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | first_link - | : - | , - | HashConstructorElement - | Association - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | action_link_group - | : - | , - | } - | end""".stripMargin - } - - "have correct structure when a method parameter is defined using whitespace" in { - val code = - """class SampleClass - | def sample_method( first_param:, second_param:) - | end - |end - |""".stripMargin - - printAst(_.primary(), code) shouldBe - """ClassDefinitionPrimary - | ClassDefinition - | class - | ClassOrModuleReference - | SampleClass - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | sample_method - | MethodParameterPart - | ( - | Parameters - | Parameter - | KeywordParameter - | first_param - | : - | , - | Parameter - | KeywordParameter - | second_param - | : - | ) - | BodyStatement - | CompoundStatement - | end - | end""".stripMargin - } - - "have correct structure when method parameters are defined using new line" in { - val code = - """class SomeClass - | def initialize( - | name, age) - | end - |end - |""".stripMargin - - printAst(_.primary(), code) shouldBe - """ClassDefinitionPrimary - | ClassDefinition - | class - | ClassOrModuleReference - | SomeClass - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | initialize - | MethodParameterPart - | ( - | Parameters - | Parameter - | MandatoryParameter - | name - | , - | Parameter - | MandatoryParameter - | age - | ) - | BodyStatement - | CompoundStatement - | end - | end""".stripMargin - } - - "have correct structure when method parameters are defined using wsOrNL" in { - val code = - """class SomeClass - | def initialize( - | name, age - | ) - | end - |end - |""".stripMargin - - printAst(_.primary(), code) shouldBe - """ClassDefinitionPrimary - | ClassDefinition - | class - | ClassOrModuleReference - | SomeClass - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | initialize - | MethodParameterPart - | ( - | Parameters - | Parameter - | MandatoryParameter - | name - | , - | Parameter - | MandatoryParameter - | age - | ) - | BodyStatement - | CompoundStatement - | end - | end""".stripMargin - } - - "have correct structure when keyword parameters are defined using wsOrNL" in { - val code = - """class SomeClass - | def initialize( - | name: nil, age - | ) - | end - |end - |""".stripMargin - - printAst(_.primary(), code) shouldBe - """ClassDefinitionPrimary - | ClassDefinition - | class - | ClassOrModuleReference - | SomeClass - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | MethodDefinitionPrimary - | MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | initialize - | MethodParameterPart - | ( - | Parameters - | Parameter - | KeywordParameter - | name - | : - | PrimaryExpression - | VariableReferencePrimary - | PseudoVariableIdentifierVariableReference - | NilPseudoVariableIdentifier - | nil - | , - | Parameter - | MandatoryParameter - | age - | ) - | BodyStatement - | CompoundStatement - | end - | end""".stripMargin - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ModuleTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ModuleTests.scala deleted file mode 100644 index 8b8f7e64cfbf..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ModuleTests.scala +++ /dev/null @@ -1,24 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class ModuleTests extends RubyParserAbstractTest { - - "Empty module definition" should { - - "be parsed as a definition" when { - - "defined in a single line" in { - val code = """module Bar; end""" - printAst(_.moduleDefinition(), code) shouldEqual - """ModuleDefinition - | module - | ClassOrModuleReference - | Bar - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ProcDefinitionTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ProcDefinitionTests.scala deleted file mode 100644 index 5466e3bed5f4..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ProcDefinitionTests.scala +++ /dev/null @@ -1,214 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class ProcDefinitionTests extends RubyParserAbstractTest { - - "A one-line proc definition" should { - - "be parsed as a primary expression" when { - - "it contains no parameters and no statements in a brace block" in { - val code = "-> {}" - printAst(_.primary(), code) shouldBe - """ProcDefinitionPrimary - | ProcDefinition - | -> - | BraceBlockBlock - | BraceBlock - | { - | BodyStatement - | CompoundStatement - | }""".stripMargin - } - - "it contains no parameters and no statements in a do block" in { - val code = "-> do ; end" - printAst(_.primary(), code) shouldBe - """ProcDefinitionPrimary - | ProcDefinition - | -> - | DoBlockBlock - | DoBlock - | do - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains no parameters and returns a literal in a do block" in { - val code = "-> do 1 end" - printAst(_.primary(), code) shouldBe - """ProcDefinitionPrimary - | ProcDefinition - | -> - | DoBlockBlock - | DoBlock - | do - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | end""".stripMargin - } - - "it contains a mandatory parameter and no statements in a brace block" in { - val code = "-> (x) {}" - printAst(_.primary(), code) shouldBe - """ProcDefinitionPrimary - | ProcDefinition - | -> - | ( - | Parameters - | Parameter - | MandatoryParameter - | x - | ) - | BraceBlockBlock - | BraceBlock - | { - | BodyStatement - | CompoundStatement - | }""".stripMargin - } - - "it contains a mandatory parameter and no statements in a do block" in { - val code = "-> (x) do ; end" - printAst(_.primary(), code) shouldBe - """ProcDefinitionPrimary - | ProcDefinition - | -> - | ( - | Parameters - | Parameter - | MandatoryParameter - | x - | ) - | DoBlockBlock - | DoBlock - | do - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains an optional numeric parameter and no statements in a brace block" in { - val code = "->(x = 1) {}" - printAst(_.primary(), code) shouldBe - """ProcDefinitionPrimary - | ProcDefinition - | -> - | ( - | Parameters - | Parameter - | OptionalParameter - | x - | = - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | ) - | BraceBlockBlock - | BraceBlock - | { - | BodyStatement - | CompoundStatement - | }""".stripMargin - } - - "it contains a keyword parameter and no statements in a do block" in { - val code = "-> (foo: 1) do ; end" - printAst(_.primary(), code) shouldBe - """ProcDefinitionPrimary - | ProcDefinition - | -> - | ( - | Parameters - | Parameter - | KeywordParameter - | foo - | : - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | ) - | DoBlockBlock - | DoBlock - | do - | BodyStatement - | CompoundStatement - | ; - | end""".stripMargin - } - - "it contains two mandatory parameters and two puts statements in a brace block" in { - val code = "->(x, y) {puts x; puts y}" - printAst(_.primary(), code) shouldBe - """ProcDefinitionPrimary - | ProcDefinition - | -> - | ( - | Parameters - | Parameter - | MandatoryParameter - | x - | , - | Parameter - | MandatoryParameter - | y - | ) - | BraceBlockBlock - | BraceBlock - | { - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | SingleCommandOnlyInvocationWithoutParentheses - | SimpleMethodCommand - | MethodIdentifier - | puts - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x - | ; - | ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | SingleCommandOnlyInvocationWithoutParentheses - | SimpleMethodCommand - | MethodIdentifier - | puts - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | y - | }""".stripMargin - } - - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RegexTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RegexTests.scala deleted file mode 100644 index e62be02ad99f..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RegexTests.scala +++ /dev/null @@ -1,497 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class RegexTests extends RubyParserAbstractTest { - - "An empty regex literal" when { - - "by itself" should { - val code = "//" - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """LiteralPrimary - | RegularExpressionLiteral - | / - | /""".stripMargin - } - } - - "on the RHS of an assignment" should { - val code = "x = //" - - "be parsed as a single assignment to a regex literal" in { - printAst(_.expressionOrCommand(), code) shouldEqual - """ExpressionExpressionOrCommand - | SingleAssignmentExpression - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | x - | = - | MultipleRightHandSide - | ExpressionOrCommands - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | RegularExpressionLiteral - | / - | /""".stripMargin - } - } - - "as the argument to a `puts` command" should { - val code = "puts //" - - "be parsed as such" in { - printAst(_.expressionOrCommand(), code) shouldEqual - """InvocationExpressionOrCommand - | SingleCommandOnlyInvocationWithoutParentheses - | SimpleMethodCommand - | MethodIdentifier - | puts - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | RegularExpressionLiteral - | / - | /""".stripMargin - } - } - - "as the sole argument to a parenthesized invocation" should { - val code = "puts(//)" - - "be parsed as such" in { - printAst(_.expressionOrCommand(), code) shouldEqual - """ExpressionExpressionOrCommand - | PrimaryExpression - | InvocationWithParenthesesPrimary - | MethodIdentifier - | puts - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | RegularExpressionLiteral - | / - | / - | )""".stripMargin - } - } - - "as the second argument to a parenthesized invocation" should { - val code = "puts(1, //)" - - "be parsed as such" in { - printAst(_.expressionOrCommand(), code) shouldEqual - """ExpressionExpressionOrCommand - | PrimaryExpression - | InvocationWithParenthesesPrimary - | MethodIdentifier - | puts - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | , - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | RegularExpressionLiteral - | / - | / - | )""".stripMargin - } - } - - "used in a `when` clause" should { - val code = - """case foo - | when /^ch_/ - | bar - |end""".stripMargin - - "be parsed as such" in { - printAst(_.primary(), code) shouldEqual - """CaseExpressionPrimary - | CaseExpression - | case - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | foo - | WhenClause - | when - | WhenArgument - | Expressions - | PrimaryExpression - | LiteralPrimary - | RegularExpressionLiteral - | / - | ^ch_ - | / - | ThenClause - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | bar - | end""".stripMargin - } - - "used in a `unless` clause" should { - val code = - """unless /\A([^@\s]+)@((?:[-a-z0-9]+\.)+[a-z]{2,})\z/i.match?(value) - |end""".stripMargin - - "be parsed as such" in { - printAst(_.primary(), code) shouldEqual - """UnlessExpressionPrimary - | UnlessExpression - | unless - | ExpressionExpressionOrCommand - | PrimaryExpression - | ChainedInvocationPrimary - | LiteralPrimary - | RegularExpressionLiteral - | / - | \A([^@\s]+)@((?:[-a-z0-9]+\.)+[a-z]{2,})\z - | /i - | . - | MethodName - | MethodIdentifier - | MethodOnlyIdentifier - | match - | ? - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | value - | ) - | ThenClause - | CompoundStatement - | end""".stripMargin - } - } - } - } - - "A non-interpolated regex literal" when { - - "by itself" should { - val code = "/(eu|us)/" - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """LiteralPrimary - | RegularExpressionLiteral - | / - | (eu|us) - | /""".stripMargin - } - } - - "on the RHS of an assignment" should { - val code = "x = /(eu|us)/" - - "be parsed as a single assignment to a regex literal" in { - printAst(_.expressionOrCommand(), code) shouldEqual - """ExpressionExpressionOrCommand - | SingleAssignmentExpression - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | x - | = - | MultipleRightHandSide - | ExpressionOrCommands - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | RegularExpressionLiteral - | / - | (eu|us) - | /""".stripMargin - } - } - - "as the argument to a `puts` command" should { - val code = "puts /(eu|us)/" - - "be parsed as such" in { - printAst(_.expressionOrCommand(), code) shouldEqual - """InvocationExpressionOrCommand - | SingleCommandOnlyInvocationWithoutParentheses - | SimpleMethodCommand - | MethodIdentifier - | puts - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | RegularExpressionLiteral - | / - | (eu|us) - | /""".stripMargin - } - } - - "as the argument to an parenthesized invocation" should { - val code = "puts(/(eu|us)/)" - - "be parsed as such" in { - printAst(_.expressionOrCommand(), code) shouldEqual - """ExpressionExpressionOrCommand - | PrimaryExpression - | InvocationWithParenthesesPrimary - | MethodIdentifier - | puts - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | RegularExpressionLiteral - | / - | (eu|us) - | / - | )""".stripMargin - } - } - } - - "A quoted non-interpolated (`%r`) regex literal" when { - - "by itself and using the `{`-`}` delimiters" should { - - "be parsed as a primary expression" in { - val code = "%r{a-z}" - printAst(_.primary(), code) shouldEqual - """QuotedRegexInterpolationPrimary - | QuotedRegexInterpolation - | %r{ - | a-z - | }""".stripMargin - } - } - - "by itself and using the `<`-`>` delimiters" should { - - "be parsed as a primary expression" in { - val code = "%r" - printAst(_.primary(), code) shouldEqual - """QuotedRegexInterpolationPrimary - | QuotedRegexInterpolation - | %r< - | eu|us - | >""".stripMargin - } - } - - "by itself, empty and using the `[`-`]` delimiters" should { - - "be parsed as a primary expression" in { - val code = "%r[]" - printAst(_.primary(), code) shouldEqual - """QuotedRegexInterpolationPrimary - | QuotedRegexInterpolation - | %r[ - | ]""".stripMargin - } - } - - } - - "A (numeric literal)-interpolated regex literal" when { - - "by itself" should { - val code = "/x#{1}y/" - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """RegexInterpolationPrimary - | RegexInterpolation - | / - | x - | InterpolatedRegexSequence - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | } - | y - | /""".stripMargin - } - } - - "on the RHS of an assignment" should { - val code = "x = /x#{1}y/" - - "be parsed as a single assignment to a regex literal" in { - printAst(_.expressionOrCommand(), code) shouldEqual - """ExpressionExpressionOrCommand - | SingleAssignmentExpression - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | x - | = - | MultipleRightHandSide - | ExpressionOrCommands - | ExpressionExpressionOrCommand - | PrimaryExpression - | RegexInterpolationPrimary - | RegexInterpolation - | / - | x - | InterpolatedRegexSequence - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | } - | y - | /""".stripMargin - } - } - - "as the argument to a `puts` command" should { - val code = "puts /x#{1}y/" - - "be parsed as such" in { - printAst(_.expressionOrCommand(), code) shouldEqual - """InvocationExpressionOrCommand - | SingleCommandOnlyInvocationWithoutParentheses - | SimpleMethodCommand - | MethodIdentifier - | puts - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | RegexInterpolationPrimary - | RegexInterpolation - | / - | x - | InterpolatedRegexSequence - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | } - | y - | /""".stripMargin - } - } - - "as the argument to an parenthesized invocation" should { - val code = "puts(/x#{1}y/)" - - "be parsed as such" in { - printAst(_.expressionOrCommand(), code) shouldEqual - """ExpressionExpressionOrCommand - | PrimaryExpression - | InvocationWithParenthesesPrimary - | MethodIdentifier - | puts - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | RegexInterpolationPrimary - | RegexInterpolation - | / - | x - | InterpolatedRegexSequence - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | } - | y - | / - | )""".stripMargin - } - } - } - - "An interpolated quoted (`%r`) regex" when { - - "by itself, containing a numeric literal interpolation and text" should { - - "be parsed as a primary expression" in { - val code = """%r{x#{0}|y}""" - printAst(_.primary(), code) shouldEqual - """QuotedRegexInterpolationPrimary - | QuotedRegexInterpolation - | %r{ - | x - | DelimitedStringInterpolation - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 0 - | } - | |y - | }""".stripMargin - - } - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RescueClauseTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RescueClauseTests.scala deleted file mode 100644 index fd40cf0ed07a..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RescueClauseTests.scala +++ /dev/null @@ -1,181 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class RescueClauseTests extends RubyParserAbstractTest { - - "A rescue statement" should { - - "be parsed as a standalone statement" when { - - "in the immediate scope of a `begin` block" in { - val code = - """begin - |1/0 - |rescue ZeroDivisionError => e - |end""".stripMargin - - printAst(_.beginExpression(), code) shouldEqual - """BeginExpression - | begin - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | MultiplicativeExpression - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | / - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 0 - | RescueClause - | rescue - | ExceptionClass - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | ZeroDivisionError - | ExceptionVariableAssignment - | => - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | e - | ThenClause - | CompoundStatement - | end""".stripMargin - } - - "in the immediate scope of a `def` block" in { - val code = - """def foo; - |1/0 - |rescue ZeroDivisionError => e - |end""".stripMargin - - printAst(_.methodDefinition(), code) shouldEqual - """MethodDefinition - | def - | SimpleMethodNamePart - | DefinedMethodName - | MethodName - | MethodIdentifier - | foo - | MethodParameterPart - | BodyStatement - | CompoundStatement - | ; - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | MultiplicativeExpression - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | / - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 0 - | RescueClause - | rescue - | ExceptionClass - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | ZeroDivisionError - | ExceptionVariableAssignment - | => - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | e - | ThenClause - | CompoundStatement - | end""".stripMargin - } - - "in the immediate scope of a `do` block" in { - val code = - """foo x do |y| - |y/0 - |rescue ZeroDivisionError => e - |end""".stripMargin - - printAst(_.statement(), code) shouldEqual - """ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | ChainedCommandDoBlockInvocationWithoutParentheses - | ChainedCommandWithDoBlock - | ArgsAndDoBlockAndMethodIdCommandWithDoBlock - | MethodIdentifier - | foo - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x - | DoBlock - | do - | BlockParameter - | | - | BlockParameters - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | y - | | - | BodyStatement - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | MultiplicativeExpression - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | y - | / - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 0 - | RescueClause - | rescue - | ExceptionClass - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | ZeroDivisionError - | ExceptionVariableAssignment - | => - | VariableIdentifierOnlySingleLeftHandSide - | VariableIdentifier - | e - | ThenClause - | CompoundStatement - | end""".stripMargin - } - } - - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ReturnTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ReturnTests.scala deleted file mode 100644 index f2df5706fa48..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/ReturnTests.scala +++ /dev/null @@ -1,65 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class ReturnTests extends RubyParserAbstractTest { - - "A standalone return statement" should { - - "be parsed as statement" when { - - "it contains no arguments" in { - val code = "return" - printAst(_.statement(), code) shouldEqual - """ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | ReturnArgsInvocationWithoutParentheses - | return""".stripMargin - - } - - "it contains a scoped chain invocation" in { - val code = "return ::X.y()" - printAst(_.statement(), code) shouldEqual - """ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | ReturnArgsInvocationWithoutParentheses - | return - | Arguments - | ExpressionArgument - | PrimaryExpression - | ChainedInvocationPrimary - | SimpleScopedConstantReferencePrimary - | :: - | X - | . - | MethodName - | MethodIdentifier - | y - | BlankArgsArgumentsWithParentheses - | ( - | )""".stripMargin - } - - "it contains arguments in parentheses" in { - val code = "return(0)" - printAst(_.statement(), code) shouldEqual - """ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | ReturnWithParenthesesPrimary - | return - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 0 - | )""".stripMargin - } - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RubyLexerTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RubyLexerTests.scala deleted file mode 100644 index 51f5364d4ef7..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RubyLexerTests.scala +++ /dev/null @@ -1,1315 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyLexer.* -import io.joern.rubysrc2cpg.deprecated.parser.DeprecatedRubyLexerPostProcessor -import org.antlr.v4.runtime.* -import org.antlr.v4.runtime.Token.EOF -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.should.Matchers - -class RubyLexerTests extends AnyFlatSpec with Matchers { - - class RubySyntaxErrorListener extends BaseErrorListener { - var errors = 0 - override def syntaxError( - recognizer: Recognizer[?, ?], - offendingSymbol: Any, - line: Int, - charPositionInLine: Int, - msg: String, - e: RecognitionException - ): Unit = - errors += 1 - } - - private def tokenizer(code: String, postProcessor: TokenSource => TokenSource): Iterable[Int] = { - val lexer = new DeprecatedRubyLexer(CharStreams.fromString(code)) - val syntaxErrorListener = new RubySyntaxErrorListener - lexer.addErrorListener(syntaxErrorListener) - val stream = new CommonTokenStream(postProcessor(lexer)) - stream.fill() // Run the lexer - if (syntaxErrorListener.errors > 0) { - Seq() - } else { - import scala.jdk.CollectionConverters.CollectionHasAsScala - stream.getTokens.asScala.map(_.getType) - } - } - - def tokenize(code: String): Iterable[Int] = tokenizer(code, identity) - - def tokenizeOpt(code: String): Iterable[Int] = tokenizer(code, DeprecatedRubyLexerPostProcessor.apply) - - "Single-line comments" should "be discarded" in { - val code = - """ - |# comment 1 - | #comment 2 - |## comment 3 - |""".stripMargin - - tokenize(code) shouldBe Seq(NL, NL, WS, NL, NL, EOF) - } - - "Multi-line comments" should "only be allowed if they start and end on the first column" in { - val code = - """ - |=begin Everything delimited by this =begin..=end - |block is ignored. - | =end - |=end This is the real closing delimiter - |""".stripMargin - - tokenize(code) shouldBe Seq(NL, EOF) - } - - "End-of-program marker (`__END__`)" should "only be recognized if it's on a line of its own" in { - val code = - """ - |# not valid: - |__END__ # comment - | __END__ - |# valid: - |__END__ - |This is now part of the data section and thus removed from the - |main lexer channel. Even __END__ is removed from the main channel. - |""".stripMargin - - tokenize(code) shouldBe Seq(NL, NL, LOCAL_VARIABLE_IDENTIFIER, WS, NL, WS, LOCAL_VARIABLE_IDENTIFIER, NL, NL, EOF) - } - - "Prefixed decimal integer literals" should "be recognized as such" in { - val eg = Seq("0d123456789", "0d1_2_3", "0d0") - all(eg.map(tokenize)) shouldBe Seq(DECIMAL_INTEGER_LITERAL, EOF) - } - - "Non-prefixed decimal integer literals" should "be recognized as such" in { - val eg = Seq("123456789", "1_2_3", "0") - all(eg.map(tokenize)) shouldBe Seq(DECIMAL_INTEGER_LITERAL, EOF) - } - - "Non-prefixed octal integer literals" should "be not be mistaken for decimal integer literals" in { - val eg = Seq("07", "01_2", "01", "0_123", "00") - all(eg.map(tokenize)) shouldBe Seq(OCTAL_INTEGER_LITERAL, EOF) - } - - "Prefixed octal integer literals" should "be recognized as such" in { - val eg = Seq("0o0", "0o1_7", "0o1_2_3") - all(eg.map(tokenize)) shouldBe Seq(OCTAL_INTEGER_LITERAL, EOF) - } - - "Binary integer literals" should "be recognized as such" in { - val eg = Seq("0b0", "0b1", "0b11", "0b1_0", "0b0_1_0") - all(eg.map(tokenize)) shouldBe Seq(BINARY_INTEGER_LITERAL, EOF) - } - - "Hexadecimal integer literals" should "be recognized as such" in { - val eg = Seq("0xA", "0x0_f1", "0x0abcFF_8") - all(eg.map(tokenize)) shouldBe Seq(HEXADECIMAL_INTEGER_LITERAL, EOF) - } - - "Floating-point literals without exponent" should "be recognized as such" in { - val eg = Seq("0.0", "1_2.2_1") - all(eg.map(tokenize)) shouldBe Seq(FLOAT_LITERAL_WITHOUT_EXPONENT, EOF) - } - - "Floating-point literals with exponent" should "be recognized as such" in { - val eg = Seq("0e0", "1E+10", "12e-10", "1.2e4") - all(eg.map(tokenize)) shouldBe Seq(FLOAT_LITERAL_WITH_EXPONENT, EOF) - } - - "Keyword-named symbols" should "be recognized as such" in { - val eg = Seq(":while", ":def", ":if") - all(eg.map(tokenize)) shouldBe Seq(SYMBOL_LITERAL, EOF) - } - - "Operator-named symbols" should "be recognized as such" in { - val eg = Seq(":^", ":==", ":[]", ":[]=", ":+", ":%", ":**", ":>>", ":+@") - all(eg.map(tokenize)) shouldBe Seq(SYMBOL_LITERAL, EOF) - } - - "Assignment-like-named symbols" should "be recognized as such" in { - val eg = Seq(":X=", ":xyz=") - all(eg.map(tokenize)) shouldBe Seq(SYMBOL_LITERAL, EOF) - } - - "Local variable identifiers" should "be recognized as such" in { - val eg = Seq("i", "x1", "old_value", "_internal", "_while") - all(eg.map(tokenize)) shouldBe Seq(LOCAL_VARIABLE_IDENTIFIER, EOF) - } - - "Constant identifiers" should "be recognized as such" in { - val eg = Seq("PI", "Const") - all(eg.map(tokenize)) shouldBe Seq(CONSTANT_IDENTIFIER, EOF) - } - - "Global variable identifiers" should "be recognized as such" in { - val eg = Seq("$foo", "$Foo", "$_", "$__foo") - all(eg.map(tokenize)) shouldBe Seq(GLOBAL_VARIABLE_IDENTIFIER, EOF) - } - - "Instance variable identifiers" should "be recognized as such" in { - val eg = Seq("@x", "@_int", "@if", "@_", "@X0") - all(eg.map(tokenize)) shouldBe Seq(INSTANCE_VARIABLE_IDENTIFIER, EOF) - } - - "Class variable identifiers" should "be recognized as such" in { - val eg = Seq("@@counter", "@@if", "@@While_0") - all(eg.map(tokenize)) shouldBe Seq(CLASS_VARIABLE_IDENTIFIER, EOF) - } - - "Single-quoted string literals" should "be recognized as such" in { - val eg = Seq("''", "'\nx'", "'\\''", "'\\'\n\\\''") - all(eg.map(tokenize)) shouldBe Seq(SINGLE_QUOTED_STRING_LITERAL, EOF) - } - - "Non-interpolated, non-escaped double-quoted string literals" should "be recognized as such" in { - val eg = Seq("\"something\"", "\"x\n\"") - all(eg.map(tokenize)) shouldBe Seq( - DOUBLE_QUOTED_STRING_START, - DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE, - DOUBLE_QUOTED_STRING_END, - EOF - ) - } - - "Double-quoted string literals containing identifier interpolations" should "be recognized as such" in { - val eg = Seq("\"$x = #$x\"", "\"@xyz = #@xyz\"", "\"@@counter = #@@counter\"") - all(eg.map(tokenize)) shouldBe Seq( - DOUBLE_QUOTED_STRING_START, - DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE, - INTERPOLATED_CHARACTER_SEQUENCE, - DOUBLE_QUOTED_STRING_END, - EOF - ) - } - - "Double-quoted string literals containing escaped `#` characters" should "not be mistaken for interpolations" in { - val code = "\"x = \\#$x\"" - tokenize(code) shouldBe Seq( - DOUBLE_QUOTED_STRING_START, - DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE, - DOUBLE_QUOTED_STRING_END, - EOF - ) - } - - "Double-quoted string literals containing `#`" should "not be mistaken for interpolations" in { - val code = "\"x = #\"" - tokenize(code) shouldBe Seq( - DOUBLE_QUOTED_STRING_START, - DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE, - DOUBLE_QUOTED_STRING_END, - EOF - ) - } - - "Double-quoted string literals containing `\\u` character sequences" should "be recognized as such" in { - val code = """"AB\u0003\u0004\u0014\u0000\u0000\u0000\b\u0000\u0000\u0000!\u0000file"""" - tokenize(code) shouldBe Seq( - DOUBLE_QUOTED_STRING_START, - DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE, - DOUBLE_QUOTED_STRING_END, - EOF - ) - } - - "Interpolated double-quoted string literal" should "be recognized as such" in { - val code = "\"x is #{1+1}\"" - tokenize(code) shouldBe Seq( - DOUBLE_QUOTED_STRING_START, - DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE, - STRING_INTERPOLATION_BEGIN, - DECIMAL_INTEGER_LITERAL, - PLUS, - DECIMAL_INTEGER_LITERAL, - STRING_INTERPOLATION_END, - DOUBLE_QUOTED_STRING_END, - EOF - ) - } - - "Recursively interpolated double-quoted string literal" should "be recognized as such" in { - val code = "\"x is #{\"#{1+1}\"}!\"" - tokenize(code) shouldBe Seq( - DOUBLE_QUOTED_STRING_START, - DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE, - STRING_INTERPOLATION_BEGIN, - DOUBLE_QUOTED_STRING_START, - STRING_INTERPOLATION_BEGIN, - DECIMAL_INTEGER_LITERAL, - PLUS, - DECIMAL_INTEGER_LITERAL, - STRING_INTERPOLATION_END, - DOUBLE_QUOTED_STRING_END, - STRING_INTERPOLATION_END, - DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE, - DOUBLE_QUOTED_STRING_END, - EOF - ) - } - - "Escaped `\"` in double-quoted string literal" should "not be mistaken for end of string" in { - val code = "\"x is \\\"4\\\"\"" - tokenize(code) shouldBe Seq( - DOUBLE_QUOTED_STRING_START, - DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE, - DOUBLE_QUOTED_STRING_END, - EOF - ) - } - - "Escaped `\\)` in double-quoted string literal" should "be recognized as a single character" in { - val code = """"\)"""" - tokenize(code) shouldBe Seq( - DOUBLE_QUOTED_STRING_START, - DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE, - DOUBLE_QUOTED_STRING_END, - EOF - ) - } - - "Escaped `\\)` in `%Q` string literal" should "be recognized as a single character" in { - val code = """%Q(\))""" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "Empty regex literal" should "be recognized as such" in { - val code = "//" - tokenize(code) shouldBe Seq(REGULAR_EXPRESSION_START, REGULAR_EXPRESSION_END, EOF) - } - - "Empty regex literal on the RHS of an assignment" should "be recognized as such" in { - // This test exists to check if RubyLexer properly decided between SLASH and REGULAR_EXPRESSION_START - val code = "x = //" - tokenize(code) shouldBe Seq( - LOCAL_VARIABLE_IDENTIFIER, - WS, - EQ, - WS, - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_END, - EOF - ) - } - - "Empty regex literal on the RHS of an association" should "be recognized as such" in { - val code = "{x: //}" - tokenize(code) shouldBe Seq( - LCURLY, - LOCAL_VARIABLE_IDENTIFIER, - COLON, - WS, - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_END, - RCURLY, - EOF - ) - } - - "Non-empty regex literal on the RHS of a keyword argument" should "be recognized as such" in { - val code = "foo(x: /.*/)" - tokenize(code) shouldBe Seq( - LOCAL_VARIABLE_IDENTIFIER, - LPAREN, - LOCAL_VARIABLE_IDENTIFIER, - COLON, - WS, - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_BODY, - REGULAR_EXPRESSION_END, - RPAREN, - EOF - ) - } - - "Non-empty regex literal on the RHS of an assignment" should "be recognized as such" in { - val code = """NAME_REGEX = /\A[^0-9!\``@#\$%\^&*+_=]+\z/""" - tokenize(code) shouldBe Seq( - CONSTANT_IDENTIFIER, - WS, - EQ, - WS, - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_BODY, - REGULAR_EXPRESSION_END, - EOF - ) - } - - "Non-empty regex literal on the RHS of an regex matching operation" should "be recognized as such" in { - val code = """content_filename =~ /filename="(.*)"/""" - tokenize(code) shouldBe Seq( - LOCAL_VARIABLE_IDENTIFIER, - WS, - EQTILDE, - WS, - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_BODY, - REGULAR_EXPRESSION_END, - EOF - ) - } - - "Non-empty regex literal after `when`" should "be recognized as such" in { - val code = "when /^ch_/" - tokenize(code) shouldBe Seq( - WHEN, - WS, - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_BODY, - REGULAR_EXPRESSION_END, - EOF - ) - } - - "Non-empty regex literal after `unless`" should "be recognized as such" in { - val code = "unless /^ch_/" - tokenize(code) shouldBe Seq( - UNLESS, - WS, - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_BODY, - REGULAR_EXPRESSION_END, - EOF - ) - } - - "Regex literals without metacharacters" should "be recognized as such" in { - val eg = Seq("/regexp/", "/a regexp/") - all(eg.map(tokenize)) shouldBe Seq(REGULAR_EXPRESSION_START, REGULAR_EXPRESSION_BODY, REGULAR_EXPRESSION_END, EOF) - } - - "Regex literals with metacharacters" should "be recognized as such" in { - val eg = Seq("/(us|eu)/", "/[a-z]/", "/[A-Z]*/", "/(us|eu)?/", "/[0-9]+/") - all(eg.map(tokenize)) shouldBe Seq(REGULAR_EXPRESSION_START, REGULAR_EXPRESSION_BODY, REGULAR_EXPRESSION_END, EOF) - } - - "Regex literals with character classes" should "be recognized as such" in { - val eg = Seq("/\\w/", "/\\W/", "/\\S/") - all(eg.map(tokenize)) shouldBe Seq(REGULAR_EXPRESSION_START, REGULAR_EXPRESSION_BODY, REGULAR_EXPRESSION_END, EOF) - } - - "Regex literals with groups" should "be recognized as such" in { - val eg = Seq("/[aeiou]\\w{2}/", "/(\\d{2}:\\d{2}) (\\w+) (.*)/", "/(?\\w+) (?\\d+)/") - all(eg.map(tokenize)) shouldBe Seq(REGULAR_EXPRESSION_START, REGULAR_EXPRESSION_BODY, REGULAR_EXPRESSION_END, EOF) - } - - "Regex literals with options" should "be recognized as such" in { - val eg = Seq("/./m", "/./i", "/./x", "/./o") - all(eg.map(tokenize)) shouldBe Seq(REGULAR_EXPRESSION_START, REGULAR_EXPRESSION_BODY, REGULAR_EXPRESSION_END, EOF) - } - - "Interpolated (with a local variable) regex literal" should "be recognized as such" in { - val code = "/#{foo}/" - tokenize(code) shouldBe Seq( - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_INTERPOLATION_BEGIN, - LOCAL_VARIABLE_IDENTIFIER, - REGULAR_EXPRESSION_INTERPOLATION_END, - REGULAR_EXPRESSION_END, - EOF - ) - } - - "Interpolated (with a numeric expression) regex literal" should "be recognized as such" in { - val code = "/#{1+1}/" - tokenize(code) shouldBe Seq( - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_INTERPOLATION_BEGIN, - DECIMAL_INTEGER_LITERAL, - PLUS, - DECIMAL_INTEGER_LITERAL, - REGULAR_EXPRESSION_INTERPOLATION_END, - REGULAR_EXPRESSION_END, - EOF - ) - } - - "Interpolated (with a local variable) regex literal containing also textual body elements" should "be recognized as such" in { - val code = "/x\\.#{foo}\\./" - tokenize(code) shouldBe Seq( - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_BODY, - REGULAR_EXPRESSION_INTERPOLATION_BEGIN, - LOCAL_VARIABLE_IDENTIFIER, - REGULAR_EXPRESSION_INTERPOLATION_END, - REGULAR_EXPRESSION_BODY, - REGULAR_EXPRESSION_END, - EOF - ) - } - - "Vacuously interpolated regex literal" should "be recognized as such" in { - val code = "/#{}/" - tokenize(code) shouldBe Seq( - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_INTERPOLATION_BEGIN, - REGULAR_EXPRESSION_INTERPOLATION_END, - REGULAR_EXPRESSION_END, - EOF - ) - } - - "Division operator between identifiers" should "not be confused with regex start" in { - val code = "x / y" - tokenize(code) shouldBe Seq(LOCAL_VARIABLE_IDENTIFIER, WS, SLASH, WS, LOCAL_VARIABLE_IDENTIFIER, EOF) - } - - "Addition between class fields" should "not be confused with +@ token" in { - // This test exists to check if RubyLexer properly decided between PLUS and PLUSAT - val code = "x+@y" - tokenize(code) shouldBe Seq(LOCAL_VARIABLE_IDENTIFIER, PLUS, INSTANCE_VARIABLE_IDENTIFIER, EOF) - } - - "Subtraction between class fields" should "not be confused with -@ token" in { - // This test exists to check if RubyLexer properly decided between MINUS and MINUSAT - val code = "x-@y" - tokenize(code) shouldBe Seq(LOCAL_VARIABLE_IDENTIFIER, MINUS, INSTANCE_VARIABLE_IDENTIFIER, EOF) - } - - "Invocation of command with regex literal" should "not be confused with binary division" in { - val code = "puts /x/" - tokenize(code) shouldBe Seq( - LOCAL_VARIABLE_IDENTIFIER, - WS, - REGULAR_EXPRESSION_START, - REGULAR_EXPRESSION_BODY, - REGULAR_EXPRESSION_END, - EOF - ) - } - - "Multi-line string literal concatenation" should "be recognized as two string literals separated by whitespace" in { - val code = - """'abc' \ - |'cde'""".stripMargin - tokenize(code) shouldBe Seq(SINGLE_QUOTED_STRING_LITERAL, WS, SINGLE_QUOTED_STRING_LITERAL, EOF) - } - - "Multi-line string literal concatenation" should "be optimized as two consecutive string literals" in { - val code = - """'abc' \ - |'cde'""".stripMargin - tokenizeOpt(code) shouldBe Seq(SINGLE_QUOTED_STRING_LITERAL, SINGLE_QUOTED_STRING_LITERAL, EOF) - } - - "empty `%q` string literals" should "be recognized as such" in { - val eg = Seq("%q()", "%q[]", "%q{}", "%q<>", "%q##", "%q!!", "%q--", "%q@@", "%q++", "%q**", "%q//", "%q&&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_NON_EXPANDED_STRING_LITERAL_START, - QUOTED_NON_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "single-character `%q` string literals" should "be recognized as such" in { - val eg = - Seq("%q(x)", "%q[y]", "%q{z}", "%q", "%q#a#", "%q!b!", "%q-_-", "%q@c@", "%q+d+", "%q*e*", "%q/#/", "%q&!&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_NON_EXPANDED_STRING_LITERAL_START, - NON_EXPANDED_LITERAL_CHARACTER, - QUOTED_NON_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "delimiter-escaped-single-character `%q` string literals" should "be recognized as such" in { - val eg = Seq( - "%q(\\))", - "%q[\\]]", - "%q{\\}}", - "%q<\\>>", - "%q#\\##", - "%q!\\!!", - "%q-\\--", - "%q@\\@@", - "%q+\\++", - "%q*\\**", - "%q/\\//", - "%q&\\&&" - ) - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_NON_EXPANDED_STRING_LITERAL_START, - NON_EXPANDED_LITERAL_CHARACTER, - QUOTED_NON_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "nested `%q` string literals" should "be recognized as such" in { - val eg = Seq("%q(()())", "%q[[][]]", "%q{{}{}}", "%q<<><>>") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_NON_EXPANDED_STRING_LITERAL_START, - NON_EXPANDED_LITERAL_CHARACTER, - NON_EXPANDED_LITERAL_CHARACTER, - NON_EXPANDED_LITERAL_CHARACTER, - NON_EXPANDED_LITERAL_CHARACTER, - QUOTED_NON_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "empty `%Q` string literals" should "be recognized as such" in { - val eg = Seq("%Q()", "%Q[]", "%Q{}", "%Q<>", "%Q##", "%Q!!", "%Q--", "%Q@@", "%Q++", "%Q**", "%Q//", "%Q&&") - all(eg.map(tokenize)) shouldBe Seq(QUOTED_EXPANDED_STRING_LITERAL_START, QUOTED_EXPANDED_STRING_LITERAL_END, EOF) - } - - "single-character `%Q` string literals" should "be recognized as such" in { - val eg = - Seq("%Q(x)", "%Q[y]", "%Q{z}", "%Q", "%Q#a#", "%Q!b!", "%Q-_-", "%Q@c@", "%Q+d+", "%Q*e*", "%Q/#/", "%Q&!&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "delimiter-escaped-single-character `%Q` string literals" should "be recognized as such" in { - val eg = Seq( - "%Q(\\))", - "%Q[\\]]", - "%Q{\\}}", - "%Q<\\>>", - "%Q#\\##", - "%Q!\\!!", - "%Q-\\--", - "%Q@\\@@", - "%Q+\\++", - "%Q*\\**", - "%Q/\\//", - "%Q&\\&&" - ) - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "nested `%Q` string literals" should "be recognized as such" in { - val eg = Seq("%Q(()())", "%Q[[][]]", "%Q{{}{}}", "%Q<<><>>") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "interpolated (with a numeric expression) `%Q` string literals" should "be recognized as such" in { - val code = "%Q(#{1})" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - DELIMITED_STRING_INTERPOLATION_BEGIN, - DECIMAL_INTEGER_LITERAL, - DELIMITED_STRING_INTERPOLATION_END, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "`%Q` string literals containing identifier interpolations" should "be recognized as such" in { - val eg = Seq("%Q[x = #$x]", "%Q{x = #@xyz}", "%Q") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_VARIABLE_CHARACTER_SEQUENCE, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "`%Q` string literals containing escaped `#` characters" should "not be mistaken for interpolations" in { - val code = """%Q(\#$x)""" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "`%Q` string literals containing `#`" should "not be mistaken for interpolations" in { - val code = "%Q[#]" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "empty `%(` string literals" should "be recognized as such" in { - val code = "%()" - tokenize(code) shouldBe Seq(QUOTED_EXPANDED_STRING_LITERAL_START, QUOTED_EXPANDED_STRING_LITERAL_END, EOF) - } - - "single-character `%(` string literals" should "be recognized as such" in { - val code = "%(-)" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "delimiter-escaped-single-character `%(` string literals" should "be recognized as such" in { - val code = "%(\\))" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "nested `%(` string literals" should "be recognized as such" in { - val code = "%(()())" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "interpolated (with a numeric expression) `%(` string literals" should "be recognized as such" in { - val code = "%(#{1})" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - DELIMITED_STRING_INTERPOLATION_BEGIN, - DECIMAL_INTEGER_LITERAL, - DELIMITED_STRING_INTERPOLATION_END, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "`%(` string literals containing identifier interpolations" should "be recognized as such" in { - val eg = Seq("%(x = #$x)", "%(x = #@xyz)", "%(x = #@@counter)") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_VARIABLE_CHARACTER_SEQUENCE, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "`%(` after a decimal literal" should "not be mistaken for an expanded string literal" in { - val code = "100%(x+1)" - tokenize(code) shouldBe Seq( - DECIMAL_INTEGER_LITERAL, - PERCENT, - LPAREN, - LOCAL_VARIABLE_IDENTIFIER, - PLUS, - DECIMAL_INTEGER_LITERAL, - RPAREN, - EOF - ) - } - - "`%(` in a `puts` argument" should "be recognized as such" in { - val code = "puts %()" - tokenize(code) shouldBe Seq( - LOCAL_VARIABLE_IDENTIFIER, - WS, - QUOTED_EXPANDED_STRING_LITERAL_START, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "`%(` string literals containing escaped `#` characters" should "not be mistaken for interpolations" in { - val code = """%(\#$x)""" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "`%(` string literals containing `#`" should "not be mistaken for interpolations" in { - val code = "%(#)" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "Empty `%x` literals" should "be recognized as such" in { - val eg = Seq("%x()", "%x[]", "%x{}", "%x<>", "%x##", "%x!!", "%x--", "%x@@", "%x++", "%x**", "%x//", "%x&&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_START, - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END, - EOF - ) - } - - "`%x` literals containing `#`" should "not be mistaken for interpolations" in { - val code = "%x[#]" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END, - EOF - ) - } - - "`%x` literals containing escaped `#` characters" should "not be mistaken for interpolations" in { - val code = """%x(\#$x)""" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_START, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END, - EOF - ) - } - - "`%x` literals containing identifier interpolations" should "be recognized as such" in { - val eg = Seq("%x[#$x]", "%x{#@xyz}", "%x<#@@counter>") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_START, - EXPANDED_VARIABLE_CHARACTER_SEQUENCE, - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END, - EOF - ) - } - - "Interpolated (with a local variable) `%x` literals" should "be recognized as such" in { - val code = "%x(#{ls})" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_START, - DELIMITED_STRING_INTERPOLATION_BEGIN, - LOCAL_VARIABLE_IDENTIFIER, - DELIMITED_STRING_INTERPOLATION_END, - QUOTED_EXPANDED_EXTERNAL_COMMAND_LITERAL_END, - EOF - ) - } - - "empty `%r` regex literals" should "be recognized as such" in { - val eg = Seq("%r()", "%r[]", "%r{}", "%r<>", "%r##", "%r!!", "%r--", "%r@@", "%r++", "%r**", "%r//", "%r&&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_REGULAR_EXPRESSION_START, - QUOTED_EXPANDED_REGULAR_EXPRESSION_END, - EOF - ) - } - - "single-character `%r` regex literals" should "be recognized as such" in { - val eg = - Seq("%r(x)", "%r[y]", "%r{z}", "%r", "%r#a#", "%r!b!", "%r-_-", "%r@c@", "%r+d+", "%r*e*", "%r/#/", "%r&!&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_REGULAR_EXPRESSION_START, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_REGULAR_EXPRESSION_END, - EOF - ) - } - - "delimiter-escaped-single-character `%r` regex literals" should "be recognized as such" in { - val eg = Seq( - "%r(\\))", - "%r[\\]]", - "%r{\\}}", - "%r<\\>>", - "%r#\\##", - "%r!\\!!", - "%r-\\--", - "%r@\\@@", - "%r+\\++", - "%r*\\**", - "%r/\\//", - "%r&\\&&" - ) - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_REGULAR_EXPRESSION_START, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_REGULAR_EXPRESSION_END, - EOF - ) - } - - "nested `%r` regex literals" should "be recognized as such" in { - val eg = Seq("%r(()())", "%r[[][]]", "%r{{}{}}", "%r<<><>>") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_REGULAR_EXPRESSION_START, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - EXPANDED_LITERAL_CHARACTER, - QUOTED_EXPANDED_REGULAR_EXPRESSION_END, - EOF - ) - } - - "empty `%w` string array literals" should "be recognized as such" in { - val eg = Seq("%w()", "%w[]", "%w{}", "%w<>", "%w##", "%w!!", "%w--", "%w@@", "%w++", "%w**", "%w//", "%w&&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_START, - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "single-character `%w` string array literals" should "be recognized as such" in { - val eg = - Seq("%w(x)", "%w[y]", "%w{z}", "%w", "%w#a#", "%w!b!", "%w-_-", "%w@c@", "%w+d+", "%w*e*", "%w/#/", "%w&!&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_START, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "two-word `%w` string array literals" should "be recognized as such" in { - val eg = Seq( - "%w(xx y)", - "%w[yy z]", - "%w{z0 w}", - "%w", - "%w#a& ?#", - "%w!b_ c!", - "%w-_= +-", - "%w@c\" d@", - "%w+d/ *+", - "%w*ef <*", - "%w/#< >/", - "%w&!! %&" - ) - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_START, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_SEPARATOR, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "single word `%w` string array literal containing an escaped whitespace" should "be recognized as such" in { - val code = """%w[x\ y]""" - tokenize(code) shouldBe Seq( - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_START, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "multi-line `%w` string array literal" should "be recognized as such" in { - val code = - """%w( - | bob - | cod - | dod)""".stripMargin - tokenize(code) shouldBe Seq( - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_START, - NON_EXPANDED_ARRAY_ITEM_SEPARATOR, - NON_EXPANDED_ARRAY_ITEM_SEPARATOR, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_SEPARATOR, - NON_EXPANDED_ARRAY_ITEM_SEPARATOR, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_SEPARATOR, - NON_EXPANDED_ARRAY_ITEM_SEPARATOR, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - QUOTED_NON_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "empty `%W` string array literals" should "be recognized as such" in { - val eg = Seq("%W()", "%W[]", "%W{}", "%W<>", "%W##", "%W!!", "%W--", "%W@@", "%W++", "%W**", "%W//", "%W&&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START, - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "single-character `%W` string array literals" should "be recognized as such" in { - val eg = - Seq("%W(x)", "%W[y]", "%W{z}", "%W", "%W#a#", "%W!b!", "%W-_-", "%W@c@", "%W+d+", "%W*e*", "%W/#/", "%W&!&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START, - EXPANDED_ARRAY_ITEM_CHARACTER, - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "two-word `%W` string array literals" should "be recognized as such" in { - val eg = Seq( - "%W(xx y)", - "%W[yy z]", - "%W{z0 w}", - "%W", - "%W#a& ?#", - "%W!b_ c!", - "%W-_= +-", - "%W@c\" d@", - "%W+d/ *+", - "%W*ef <*", - "%W/#< >/", - "%W&!! %&" - ) - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_SEPARATOR, - EXPANDED_ARRAY_ITEM_CHARACTER, - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "single interpolated word `%W` string array literal" should "be recognized as such" in { - val code = "%W{#{0}}" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START, - DELIMITED_ARRAY_ITEM_INTERPOLATION_BEGIN, - DECIMAL_INTEGER_LITERAL, - DELIMITED_ARRAY_ITEM_INTERPOLATION_END, - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "single word `%W` string array literal containing text and an interpolated numeric" should "be recognized as such" in { - val code = "%W" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START, - EXPANDED_ARRAY_ITEM_CHARACTER, - DELIMITED_ARRAY_ITEM_INTERPOLATION_BEGIN, - DECIMAL_INTEGER_LITERAL, - DELIMITED_ARRAY_ITEM_INTERPOLATION_END, - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "two-word `%W` string array literal containing text and interpolated numerics" should "be recognized as such" in { - val code = "%W(x#{0} x#{1})" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START, - EXPANDED_ARRAY_ITEM_CHARACTER, - DELIMITED_ARRAY_ITEM_INTERPOLATION_BEGIN, - DECIMAL_INTEGER_LITERAL, - DELIMITED_ARRAY_ITEM_INTERPOLATION_END, - EXPANDED_ARRAY_ITEM_SEPARATOR, - EXPANDED_ARRAY_ITEM_CHARACTER, - DELIMITED_ARRAY_ITEM_INTERPOLATION_BEGIN, - DECIMAL_INTEGER_LITERAL, - DELIMITED_ARRAY_ITEM_INTERPOLATION_END, - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "single word `%W` string array literal containing an escaped whitespace" should "be recognized as such" in { - val code = """%W[x\ y]""" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_CHARACTER, - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "multi-line `%W` string array literal" should "be recognized as such" in { - val code = - """%W( - | bob - | cod - | dod)""".stripMargin - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_START, - EXPANDED_ARRAY_ITEM_SEPARATOR, - EXPANDED_ARRAY_ITEM_SEPARATOR, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_SEPARATOR, - EXPANDED_ARRAY_ITEM_SEPARATOR, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_SEPARATOR, - EXPANDED_ARRAY_ITEM_SEPARATOR, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_CHARACTER, - QUOTED_EXPANDED_STRING_ARRAY_LITERAL_END, - EOF - ) - } - - "empty `%i` symbol array literals" should "be recognized as such" in { - val eg = Seq("%i()", "%i[]", "%i{}", "%i<>", "%i##", "%i!!", "%i--", "%i@@", "%i++", "%i**", "%i//", "%i&&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_START, - QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - EOF - ) - } - - "single-character `%i` symbol array literals" should "be recognized as such" in { - val eg = - Seq("%i(x)", "%i[y]", "%i{z}", "%i", "%i#a#", "%i!b!", "%i-_-", "%i@c@", "%i+d+", "%i*e*", "%i/#/", "%i&!&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_START, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - EOF - ) - } - - "two-word `%i` symbol array literals" should "be recognized as such" in { - val eg = Seq( - "%i(xx y)", - "%i[yy z]", - "%i{z0 w}", - "%i", - "%i#a& ?#", - "%i!b_ c!", - "%i-_= +-", - "%i@c\" d@", - "%i+d/ *+", - "%i*ef <*", - "%i/#< >/", - "%i&!! %&" - ) - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_START, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_SEPARATOR, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - EOF - ) - } - - "multi-line two-word `%i` symbol array literals" should "be recognized as such" in { - val code = - """%i( - |x - |y - |)""".stripMargin - tokenize(code) shouldBe Seq( - QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_START, - NON_EXPANDED_ARRAY_ITEM_SEPARATOR, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_SEPARATOR, - NON_EXPANDED_ARRAY_ITEM_CHARACTER, - NON_EXPANDED_ARRAY_ITEM_SEPARATOR, - QUOTED_NON_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - EOF - ) - } - - "empty `%I` symbol array literals" should "be recognized as such" in { - val eg = Seq("%I()", "%I[]", "%I{}", "%I<>", "%I##", "%I!!", "%I--", "%I@@", "%I++", "%I**", "%I//", "%I&&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_START, - QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - EOF - ) - } - - "single-character `%I` symbol array literals" should "be recognized as such" in { - val eg = - Seq("%I(x)", "%I[y]", "%I{z}", "%I", "%I#a#", "%I!b!", "%I-_-", "%I@c@", "%I+d+", "%I*e*", "%I/#/", "%I&!&") - all(eg.map(tokenize)) shouldBe Seq( - QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_START, - EXPANDED_ARRAY_ITEM_CHARACTER, - QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - EOF - ) - } - - "two-word `%I` symbol array literal containing text and interpolated numerics" should "be recognized as such" in { - val code = "%I(x#{0} x#{1})" - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_START, - EXPANDED_ARRAY_ITEM_CHARACTER, - DELIMITED_ARRAY_ITEM_INTERPOLATION_BEGIN, - DECIMAL_INTEGER_LITERAL, - DELIMITED_ARRAY_ITEM_INTERPOLATION_END, - EXPANDED_ARRAY_ITEM_SEPARATOR, - EXPANDED_ARRAY_ITEM_CHARACTER, - DELIMITED_ARRAY_ITEM_INTERPOLATION_BEGIN, - DECIMAL_INTEGER_LITERAL, - DELIMITED_ARRAY_ITEM_INTERPOLATION_END, - QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - EOF - ) - } - - "multi-line two-word `%I` symbol array literals" should "be recognized as such" in { - val code = - """%I( - |x - |y - |)""".stripMargin - tokenize(code) shouldBe Seq( - QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_START, - EXPANDED_ARRAY_ITEM_SEPARATOR, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_SEPARATOR, - EXPANDED_ARRAY_ITEM_CHARACTER, - EXPANDED_ARRAY_ITEM_SEPARATOR, - QUOTED_EXPANDED_SYMBOL_ARRAY_LITERAL_END, - EOF - ) - } - - "identifier used in a keyword argument" should "not be mistaken for a symbol literal" in { - // This test exists to check if RubyLexer properly decided between COLON and SYMBOL_LITERAL - val code = "foo(x:y)" - tokenize(code) shouldBe Seq( - LOCAL_VARIABLE_IDENTIFIER, - LPAREN, - LOCAL_VARIABLE_IDENTIFIER, - COLON, - LOCAL_VARIABLE_IDENTIFIER, - RPAREN, - EOF - ) - } - - "instance variable used in a keyword argument" should "not be mistaken for a symbol literal" in { - // This test exists to check if RubyLexer properly decided between COLON and SYMBOL_LITERAL - val code = "foo(x:@y)" - tokenize(code) shouldBe Seq( - LOCAL_VARIABLE_IDENTIFIER, - LPAREN, - LOCAL_VARIABLE_IDENTIFIER, - COLON, - INSTANCE_VARIABLE_IDENTIFIER, - RPAREN, - EOF - ) - } - - "operator-named symbol used in a whitespace-free `=>` association" should "not be include `=` as part of its name" in { - // This test exists to check if RubyLexer properly recognizes EQGT - val code = "{:x=>1}" - tokenize(code) shouldBe Seq(LCURLY, SYMBOL_LITERAL, EQGT, DECIMAL_INTEGER_LITERAL, RCURLY, EOF) - } - - "class variable used in a keyword argument" should "not be mistaken for a symbol literal" in { - // This test exists to check if RubyLexer properly decided between COLON and SYMBOL_LITERAL - val code = "foo(x:@@y)" - tokenize(code) shouldBe Seq( - LOCAL_VARIABLE_IDENTIFIER, - LPAREN, - LOCAL_VARIABLE_IDENTIFIER, - COLON, - CLASS_VARIABLE_IDENTIFIER, - RPAREN, - EOF - ) - } - - "Regex match global variables" should "be recognized as such" in { - val eg = Seq("$0", "$10", "$2", "$3") - all(eg.map(tokenize)) shouldBe Seq(GLOBAL_VARIABLE_IDENTIFIER, EOF) - } - - "Assignment-like method identifiers" should "be recognized as such" in { - val eg = Seq("def x=", "def X=") - all(eg.map(tokenize)) shouldBe Seq(DEF, WS, ASSIGNMENT_LIKE_METHOD_IDENTIFIER, EOF) - } - - "Unrecognized escape character" should "emit an UNRECOGNIZED token" in { - val code = "\\!" - tokenize(code) shouldBe Seq(UNRECOGNIZED, EMARK, EOF) - } - - "Single NON_EXPANDED_LITERAL_CHARACTER token" should "be rewritten into a single NON_EXPANDED_LITERAL_CHARACTER_SEQUENCE token" in { - val code = "%q{ }" - tokenizeOpt(code) shouldBe Seq( - QUOTED_NON_EXPANDED_STRING_LITERAL_START, - NON_EXPANDED_LITERAL_CHARACTER_SEQUENCE, - QUOTED_NON_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "Consecutive NON_EXPANDED_LITERAL_CHARACTER tokens" should "be rewritten into a single NON_EXPANDED_LITERAL_CHARACTER_SEQUENCE token" in { - val code = "%q(1 2 3 4)" - tokenizeOpt(code) shouldBe Seq( - QUOTED_NON_EXPANDED_STRING_LITERAL_START, - NON_EXPANDED_LITERAL_CHARACTER_SEQUENCE, - QUOTED_NON_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "Single EXPANDED_LITERAL_CHARACTER token" should "be rewritten into a single EXPANDED_LITERAL_CHARACTER_SEQUENCE token" in { - val code = "%Q( )" - tokenizeOpt(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER_SEQUENCE, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } - - "Consecutive EXPANDED_LITERAL_CHARACTER tokens" should "be rewritten into a single EXPANDED_LITERAL_CHARACTER_SEQUENCE token" in { - val code = "%Q{1 2 3 4 5}" - tokenizeOpt(code) shouldBe Seq( - QUOTED_EXPANDED_STRING_LITERAL_START, - EXPANDED_LITERAL_CHARACTER_SEQUENCE, - QUOTED_EXPANDED_STRING_LITERAL_END, - EOF - ) - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RubyParserAbstractTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RubyParserAbstractTest.scala deleted file mode 100644 index b9754656ad84..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/RubyParserAbstractTest.scala +++ /dev/null @@ -1,31 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -import io.joern.rubysrc2cpg.parser.AstPrinter -import org.antlr.v4.runtime.{CharStreams, CommonTokenStream, ParserRuleContext} -import org.scalatest.matchers.should.Matchers -import org.scalatest.wordspec.AnyWordSpec - -import java.util.stream.Collectors - -// TODO: Should share the same lexer/token stream/parser as the frontend itself. -// See `io.joern.rubysrc2cpg.astcreation.AntlrParser` -abstract class RubyParserAbstractTest extends AnyWordSpec with Matchers { - - def rubyStream(code: String): CommonTokenStream = - new CommonTokenStream(DeprecatedRubyLexerPostProcessor(new DeprecatedRubyLexer(CharStreams.fromString(code)))) - - def rubyParser(code: String): DeprecatedRubyParser = - new DeprecatedRubyParser(rubyStream(code)) - - def printAst(withContext: DeprecatedRubyParser => ParserRuleContext, input: String): String = - omitWhitespaceLines(AstPrinter.print(withContext(rubyParser(input)))) - - private def omitWhitespaceLines(text: String): String = - text.lines().filter(_.strip().nonEmpty).collect(Collectors.joining("\n")) - - def accepts(withContext: DeprecatedRubyParser => ParserRuleContext, input: String): Boolean = { - val parser = rubyParser(input) - withContext(parser) - parser.getNumberOfSyntaxErrors == 0 - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/StringTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/StringTests.scala deleted file mode 100644 index 16f0df7d4db2..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/StringTests.scala +++ /dev/null @@ -1,671 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class StringTests extends RubyParserAbstractTest { - - "A single-quoted string literal" when { - - "empty" should { - val code = "''" - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """StringExpressionPrimary - | SimpleStringExpression - | SingleQuotedStringLiteral - | ''""".stripMargin - } - } - - "separated by whitespace" should { - val code = "'x' 'y'" - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """StringExpressionPrimary - | ConcatenatedStringExpression - | SimpleStringExpression - | SingleQuotedStringLiteral - | 'x' - | SimpleStringExpression - | SingleQuotedStringLiteral - | 'y'""".stripMargin - } - } - - "separated by '\\\\n' " should { - val code = """'x' \ - | 'y'""".stripMargin - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """StringExpressionPrimary - | ConcatenatedStringExpression - | SimpleStringExpression - | SingleQuotedStringLiteral - | 'x' - | SimpleStringExpression - | SingleQuotedStringLiteral - | 'y'""".stripMargin - } - } - - "separated by '\\\\n' twice" should { - val code = - """'x' \ - | 'y' \ - | 'z'""".stripMargin - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """StringExpressionPrimary - | ConcatenatedStringExpression - | SimpleStringExpression - | SingleQuotedStringLiteral - | 'x' - | ConcatenatedStringExpression - | SimpleStringExpression - | SingleQuotedStringLiteral - | 'y' - | SimpleStringExpression - | SingleQuotedStringLiteral - | 'z'""".stripMargin - } - } - } - - "A non-expanded `%q` string literal" should { - - "be parsed as a primary expression" when { - - "it is empty and uses the `(`-`)` delimiters" in { - val code = "%q()" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q( - | )""".stripMargin - } - - "it is empty and uses the `[`-`]` delimiters" in { - val code = "%q[]" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q[ - | ]""".stripMargin - } - - "it is empty and uses the `{`-`}` delimiters" in { - val code = "%q{}" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q{ - | }""".stripMargin - } - - "it is empty and uses the `<`-`>` delimiters" in { - val code = "%q<>" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q< - | >""".stripMargin - } - - "it is empty and uses the `#` delimiters" in { - val code = "%q##" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q# - | #""".stripMargin - } - - "it contains a single non-escaped character and uses the `(`-`)` delimiters" in { - val code = "%q(x)" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q( - | x - | )""".stripMargin - } - - "it contains a single non-escaped character and uses the `[`-`]` delimiters" in { - val code = "%q[x]" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q[ - | x - | ]""".stripMargin - } - - "it contains a single non-escaped character and uses the `#` delimiters" in { - val code = "%q#x#" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q# - | x - | #""".stripMargin - } - - "it contains a single escaped character and uses the `(`-`)` delimiters" in { - val code = "%q(\\()" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q( - | \( - | )""".stripMargin - } - - "it contains a single escaped character and uses the `[`-`]` delimiters" in { - val code = "%q[\\]]" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q[ - | \] - | ]""".stripMargin - } - - "it contains a single escaped character and uses the `#` delimiters" in { - val code = "%q#\\##" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q# - | \# - | #""".stripMargin - } - - "it contains a word and uses the `(`-`)` delimiters" in { - val code = "%q(foo)" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q( - | foo - | )""".stripMargin - } - - "it contains an empty nested string using the `(`-`)` delimiters" in { - val code = "%q( () )" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q( - | () - | )""".stripMargin - } - - "it contains an escaped single-character nested string using the `(`-`)` delimiters" in { - val code = "%q( (\\)) )" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q( - | (\)) - | )""".stripMargin - } - - "it contains an escaped single-character nested string using the `<`-`>` delimiters" in { - val code = "%q< <\\>> >" - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | NonExpandedQuotedStringLiteral - | %q< - | <\>> - | >""".stripMargin - } - } - } - - "An expanded `%Q` string literal" when { - - "empty" should { - val code = "%Q()" - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | ExpandedQuotedStringLiteral - | %Q( - | )""".stripMargin - } - } - - "containing text and a numeric literal interpolation" should { - val code = "%Q{text=#{1}}" - - "be parsed as primary expression" in { - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | ExpandedQuotedStringLiteral - | %Q{ - | text= - | DelimitedStringInterpolation - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | } - | }""".stripMargin - } - } - - "containing two consecutive numeric literal interpolations" should { - val code = "%Q[#{1}#{2}]" - - "be parsed as primary expression" in { - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | ExpandedQuotedStringLiteral - | %Q[ - | DelimitedStringInterpolation - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | } - | DelimitedStringInterpolation - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 2 - | } - | ]""".stripMargin - } - } - - } - - "An expanded `%(` string literal" when { - - "empty" should { - val code = "%()" - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | ExpandedQuotedStringLiteral - | %( - | )""".stripMargin - } - } - - "containing text and a numeric literal interpolation" should { - val code = "%(text=#{1})" - - "be parsed as primary expression" in { - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | ExpandedQuotedStringLiteral - | %( - | text= - | DelimitedStringInterpolation - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | } - | )""".stripMargin - } - } - - "containing two consecutive numeric literal interpolations" should { - val code = "%(#{1}#{2})" - - "be parsed as primary expression" in { - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | ExpandedQuotedStringLiteral - | %( - | DelimitedStringInterpolation - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | } - | DelimitedStringInterpolation - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 2 - | } - | )""".stripMargin - } - } - - "used as the argument to a `puts` command" should { - - "be parsed as a statement" in { - val code = "puts %()" - printAst(_.statement(), code) shouldEqual - """ExpressionOrCommandStatement - | InvocationExpressionOrCommand - | SingleCommandOnlyInvocationWithoutParentheses - | SimpleMethodCommand - | MethodIdentifier - | puts - | ArgumentsWithoutParentheses - | Arguments - | ExpressionArgument - | PrimaryExpression - | QuotedStringExpressionPrimary - | ExpandedQuotedStringLiteral - | %( - | )""".stripMargin - } - } - } - - "A double-quoted string literal" when { - - "empty" should { - val code = "\"\"" - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """StringExpressionPrimary - | SimpleStringExpression - | DoubleQuotedStringLiteral - | " - | """".stripMargin - } - - "separated by whitespace" should { - val code = "\"x\" \"y\"" - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """StringExpressionPrimary - | ConcatenatedStringExpression - | SimpleStringExpression - | DoubleQuotedStringLiteral - | " - | x - | " - | SimpleStringExpression - | DoubleQuotedStringLiteral - | " - | y - | """".stripMargin - } - } - - "separated by '\\\\n'" should { - val code = - """"x" \ - | "y" """.stripMargin - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """StringExpressionPrimary - | ConcatenatedStringExpression - | SimpleStringExpression - | DoubleQuotedStringLiteral - | " - | x - | " - | SimpleStringExpression - | DoubleQuotedStringLiteral - | " - | y - | """".stripMargin - } - } - } - - "containing text and a numeric literal interpolation" should { - val code = """"text=#{1}"""" - - "be parsed as primary expression" in { - printAst(_.primary(), code) shouldEqual - """StringExpressionPrimary - | InterpolatedStringExpression - | StringInterpolation - | " - | text= - | InterpolatedStringSequence - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | } - | """".stripMargin - } - } - - "containing two numeric literal interpolations" should { - val code = """"#{1}#{2}"""" - - "be parsed as primary expression" in { - printAst(_.primary(), code) shouldEqual - """StringExpressionPrimary - | InterpolatedStringExpression - | StringInterpolation - | " - | InterpolatedStringSequence - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 1 - | } - | InterpolatedStringSequence - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 2 - | } - | """".stripMargin - } - } - - "separated by '\\\\n'" should { - val code = """"x" \ - | "y" """.stripMargin - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """StringExpressionPrimary - | ConcatenatedStringExpression - | SimpleStringExpression - | DoubleQuotedStringLiteral - | " - | x - | " - | SimpleStringExpression - | DoubleQuotedStringLiteral - | " - | y - | """".stripMargin - } - - "separated by '\\\\n' and containing a numeric interpolation" should { - val code = """"#{10}" \ - | "is a number."""".stripMargin - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """StringExpressionPrimary - | ConcatenatedStringExpression - | InterpolatedStringExpression - | StringInterpolation - | " - | InterpolatedStringSequence - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 10 - | } - | " - | SimpleStringExpression - | DoubleQuotedStringLiteral - | " - | is a number. - | """".stripMargin - } - } - } - } - - "An expanded `%x` external command literal" when { - - "empty" should { - val code = "%x//" - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | ExpandedExternalCommandLiteral - | %x/ - | /""".stripMargin - } - } - - "containing text and a string literal interpolation" should { - val code = "%x{l#{'s'}}" - - "be parsed as primary expression" in { - printAst(_.primary(), code) shouldEqual - """QuotedStringExpressionPrimary - | ExpandedExternalCommandLiteral - | %x{ - | l - | DelimitedStringInterpolation - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | StringExpressionPrimary - | SimpleStringExpression - | SingleQuotedStringLiteral - | 's' - | } - | }""".stripMargin - } - } - } - - "A HERE_DOCs expression" when { - - "used to generate a single string" should { - val code = - """<<-SQL - |SELECT * FROM food - |WHERE healthy = true - |SQL - |""".stripMargin - - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """LiteralPrimary - | HereDocLiteral - | <<-SQL - |SELECT * FROM food - |WHERE healthy = true - |SQL""".stripMargin - } - - } - - "used to generate a single string parameter for a function call" should { - val code = - """foo(<<-SQL) - |SELECT * FROM food - |WHERE healthy = true - |SQL - |""".stripMargin - - // TODO: The rest of the HERE_DOC should probably be parsed somehow - "be parsed as a primary expression" in { - printAst(_.primary(), code) shouldEqual - """InvocationWithParenthesesPrimary - | MethodIdentifier - | foo - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | HereDocArgument - | <<-SQL - | )""".stripMargin - } - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/SymbolTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/SymbolTests.scala deleted file mode 100644 index a48bb91d8167..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/SymbolTests.scala +++ /dev/null @@ -1,132 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class SymbolTests extends RubyParserAbstractTest { - - "Symbol literals" should { - - "be parsed as primary expressions" when { - - def symbolLiteralParseTreeText(symbolName: String): String = - s"""LiteralPrimary - | SymbolLiteral - | Symbol - | $symbolName""".stripMargin - - "they are named after keywords" in { - val eg = Seq( - ":__LINE__", - ":__ENCODING__", - ":__FILE__", - ":BEGIN", - ":END", - ":alias", - ":begin", - ":break", - ":case", - ":class", - ":def", - ":defined?", - ":do", - ":else", - ":elsif", - ":end", - ":ensure", - ":for", - ":false", - ":if", - ":in", - ":module", - ":next", - ":nil", - ":not", - ":or", - ":redo", - ":rescue", - ":retry", - ":self", - ":super", - ":then", - ":true", - ":undef", - ":unless", - ":until", - ":when", - ":while", - ":yield" - ) - eg.map(code => printAst(_.primary(), code)) shouldEqual eg.map(symbolLiteralParseTreeText) - } - - "they are named after operators" in { - val eg = Seq( - ":^", - ":&", - ":|", - ":<=>", - ":==", - ":===", - ":=~", - ":>", - ":>=", - ":<", - ":<=", - ":<<", - ":>>", - ":+", - ":-", - ":*", - ":/", - ":%", - ":**", - ":~", - ":+@", - ":-@", - ":[]", - ":[]=" - ) - eg.map(code => printAst(_.primary(), code)) shouldEqual eg.map(symbolLiteralParseTreeText) - } - - "they are given by a non-interpolated double-quoted string literal" in { - val code = """:"x y z"""" - printAst(_.primary(), code) shouldEqual - """LiteralPrimary - | SymbolLiteral - | Symbol - | : - | SimpleStringExpression - | DoubleQuotedStringLiteral - | " - | x y z - | """".stripMargin - } - - "they are given by an interpolated double-quoted string literal" in { - val code = """:"#{10}"""" - printAst(_.primary(), code) shouldEqual - """LiteralPrimary - | SymbolLiteral - | Symbol - | : - | InterpolatedStringExpression - | StringInterpolation - | " - | InterpolatedStringSequence - | #{ - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | LiteralPrimary - | NumericLiteralLiteral - | NumericLiteral - | UnsignedNumericLiteral - | 10 - | } - | """".stripMargin - } - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/TernaryConditionalTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/TernaryConditionalTests.scala deleted file mode 100644 index af86bb1d5f5d..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/TernaryConditionalTests.scala +++ /dev/null @@ -1,62 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class TernaryConditionalTests extends RubyParserAbstractTest { - - "Ternary conditional expressions" should { - - "be parsed as expressions" when { - - "they are a standalone one-line expression" in { - val code = "x ? y : z" - printAst(_.expression(), code) shouldEqual - """ConditionalOperatorExpression - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x - | ? - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | y - | : - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | z""".stripMargin - - } - - "they are a standalone multi-line expression" in { - val code = - """x ? - | y - |: z - |""".stripMargin - printAst(_.expression(), code) shouldEqual - """ConditionalOperatorExpression - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | x - | ? - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | y - | : - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | z""".stripMargin - } - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/UnlessConditionTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/UnlessConditionTests.scala deleted file mode 100644 index bc307bf80af3..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/parser/UnlessConditionTests.scala +++ /dev/null @@ -1,136 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.parser - -class UnlessConditionTests extends RubyParserAbstractTest { - - "An unless expression" should { - "be parsed as a primary expression" when { - - "it uses a newline instead of the keyword then" in { - val code = - """unless foo - | bar - |end - |""".stripMargin - - printAst(_.primary(), code) shouldEqual - """UnlessExpressionPrimary - | UnlessExpression - | unless - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | foo - | ThenClause - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | bar - | end""".stripMargin - } - - "it uses a semicolon instead of the keyword then" in { - val code = - """unless foo; bar - |end - |""".stripMargin - - printAst(_.primary(), code) shouldEqual - """UnlessExpressionPrimary - | UnlessExpression - | unless - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | foo - | ThenClause - | ; - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | bar - | end""".stripMargin - } - - "it uses the keyword then" in { - val code = - """unless foo then - | bar - |end - |""".stripMargin - - printAst(_.primary(), code) shouldEqual - """UnlessExpressionPrimary - | UnlessExpression - | unless - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | foo - | ThenClause - | then - | CompoundStatement - | Statements - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | bar - | end""".stripMargin - } - } - } - - "An unless (modifier) statement" should { - - "be parsed as a statement" when { - - "it explicitly returns an identifier out of a method" in { - val code = "return(value) unless item" - - printAst(_.statement(), code) shouldEqual - """ModifierStatement - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | ReturnWithParenthesesPrimary - | return - | ArgsOnlyArgumentsWithParentheses - | ( - | Arguments - | ExpressionArgument - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | value - | ) - | unless - | ExpressionOrCommandStatement - | ExpressionExpressionOrCommand - | PrimaryExpression - | VariableReferencePrimary - | VariableIdentifierVariableReference - | VariableIdentifier - | item""".stripMargin - } - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ConfigFileCreationPassTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ConfigFileCreationPassTest.scala deleted file mode 100644 index ecd048bbb4a4..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ConfigFileCreationPassTest.scala +++ /dev/null @@ -1,62 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes - -import better.files.File -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.joern.x2cpg.passes.frontend.MetaDataPass -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.Languages -import io.shiftleft.semanticcpg.language._ -import org.scalatest.matchers.should.Matchers -import org.scalatest.wordspec.AnyWordSpec - -class ConfigFileCreationPassTest extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "ConfigFileCreationPass for Gemfile files" should { - - "generate a ConfigFile accordingly" in { - val gemFileContents = - """ - |source 'https://rubygems.org' - |gem 'json' - |""".stripMargin - val cpg = code(gemFileContents, "Gemfile") - - val List(configFile) = cpg.configFile.l - configFile.name shouldBe "Gemfile" - configFile.content shouldBe gemFileContents - } - - "ignore non-root Gemfile files" in { - val cpg = code("# ignore me", Seq("subdir", "Gemfile").mkString(java.io.File.pathSeparator)) - cpg.configFile.size shouldBe 0 - } - } - - "ConfigFileCreationPass for Gemfile.lock files" should { - - "generate a ConfigFile accordingly" in { - val gemFileContents = - """ - |GEM - | remote: https://rubygems.org/ - | specs: - | CFPropertyList (3.0.1) - | - |PLATFORMS - | ruby - | - |BUNDLED WITH - | 2.1.4 - |""".stripMargin - val cpg = code(gemFileContents, "Gemfile.lock") - val List(configFile) = cpg.configFile.l - configFile.name shouldBe "Gemfile.lock" - configFile.content shouldBe gemFileContents - } - - "ignore non-root Gemfile.lock files" in { - val cpg = code("# ignore me", Seq("subdir", "Gemfile.lock").mkString(java.io.File.pathSeparator)) - cpg.configFile.size shouldBe 0 - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/MetaDataPassTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/MetaDataPassTests.scala deleted file mode 100644 index eeeb970917c1..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/MetaDataPassTests.scala +++ /dev/null @@ -1,22 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes - -import better.files.File -import io.joern.rubysrc2cpg.{Config, RubySrc2Cpg} -import io.shiftleft.codepropertygraph.generated.Languages -import io.shiftleft.semanticcpg.language.* -import org.scalatest.matchers.should.Matchers -import org.scalatest.wordspec.AnyWordSpec - -class MetaDataPassTests extends AnyWordSpec with Matchers { - - "MetaDataPass" should { - - "create a metadata node with correct language" in { - File.usingTemporaryDirectory("rubysrc2cpgTest") { dir => - val config = Config().withInputPath(dir.pathAsString).withOutputPath(dir.pathAsString) - val cpg = new RubySrc2Cpg().createCpg(config).get - cpg.metaData.language.l shouldBe List(Languages.RUBYSRC) - } - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyTypeRecoveryTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyTypeRecoveryTests.scala deleted file mode 100644 index 12a034b4ae94..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/RubyTypeRecoveryTests.scala +++ /dev/null @@ -1,258 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes - -import io.joern.rubysrc2cpg.deprecated.utils.PackageTable -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.joern.x2cpg.Defines as XDefines -import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.semanticcpg.language.importresolver.* -import io.shiftleft.semanticcpg.language.* - -import scala.collection.immutable.List - -object RubyTypeRecoveryTests { - def getPackageTable: PackageTable = { - val packageTable = PackageTable() - packageTable.addTypeDecl("sendgrid-ruby", "API", "SendGrid.API") - packageTable.addModule("dbi", "DBI", "DBI") - packageTable.addTypeDecl("logger", "Logger", "Logger") - packageTable.addModule("stripe", "Customer", "Stripe.Customer") - packageTable - } - -} -class RubyTypeRecoveryTests - extends RubyCode2CpgFixture( - withPostProcessing = true, - packageTable = Some(RubyTypeRecoveryTests.getPackageTable), - useDeprecatedFrontend = true - ) { - - "Type information for nodes with external dependency" should { - - val cpg = code( - """ - |require "sendgrid-ruby" - | - |def func - | sg = SendGrid::API.new(api_key: ENV['SENDGRID_API_KEY']) - | response = sg.client.mail._('send').post(request_body: data) - |end - |""".stripMargin, - "main.rb" - ) - - "be present in (Case 1)" ignore { - cpg.identifier("sg").lineNumber(5).typeFullName.l shouldBe List("sendgrid-ruby::program.SendGrid.API") - cpg.call("client").dispatchType.l shouldBe List(DispatchTypes.DYNAMIC_DISPATCH) - cpg.call("client").methodFullName.l shouldBe List("sendgrid-ruby::program.SendGrid.API.client") - } - - "be present in (Case 2)" ignore { - cpg.call("post").methodFullName.l shouldBe List( - "sendgrid-ruby::program.SendGrid.API.client.mail.anonymous.post" - ) - } - } - - "literals declared from built-in types" should { - val cpg = code( - """ - |x = 123 - | - |def newfunc - | x = "foo" - |end - |module MyNamespace - | MY_CONSTANT = 42 - |end - |""".stripMargin, - "main.rb" - ) - "resolve 'x' identifier types despite shadowing" in { - val List(xOuterScope, xInnerScope) = cpg.identifier("x").take(2).l - xOuterScope.dynamicTypeHintFullName shouldBe Seq("__builtin.Integer", "__builtin.String") - xInnerScope.dynamicTypeHintFullName shouldBe Seq("__builtin.Integer", "__builtin.String") - } - - "resolve module constant type" in { - cpg.typeDecl("MyNamespace").size shouldBe 1 - val List(typeDecl) = cpg.typeDecl("MyNamespace").l - val List(myconst) = typeDecl.member.l - myconst.typeFullName shouldBe "__builtin.Integer" - } - } - - "recovering paths for built-in calls" should { - lazy val cpg = code( - """ - |print("Hello world") - |puts "Hello" - | - |def sleep(input) - |end - | - |sleep(2) - |""".stripMargin, - "main.rb" - ).cpg - - "resolve 'print' and 'puts' calls" in { - val List(printCall) = cpg.call("print").l - printCall.methodFullName shouldBe "__builtin.print" - val List(maxCall) = cpg.call("puts").l - maxCall.methodFullName shouldBe "__builtin.puts" - } - - "present the declared method name when a built-in with the same name is used in the same compilation unit" in { - val List(absCall) = cpg.call("sleep").l - absCall.methodFullName shouldBe "main.rb::program.sleep" - } - } - - "recovering module members across modules" should { - lazy val cpg = code( - """ - |require "dbi" - | - |module FooModule - | x = 1 - | y = "test" - | db = DBI.connect("DBI:Mysql:TESTDB:localhost", "testuser", "test123") - |end - | - |""".stripMargin, - "foo.rb" - ).moreCode( - """ - |require_relative "./foo.rb" - | - |z = FooModule::x - |z = FooModule::y - | - |d = FooModule::db - | - |row = d.select_one("SELECT VERSION()") - | - |""".stripMargin, - "bar.rb" - ).cpg - - // TODO Waiting for Module modelling to be done - "resolve correct imports via tag nodes" ignore { - val List(foo: ResolvedTypeDecl) = - cpg.file(".*foo.rb").ast.isCall.where(_.referencedImports).tag._toEvaluatedImport.toList: @unchecked - foo.fullName shouldBe "dbi::program.DBI" - val List(bar: ResolvedTypeDecl) = - cpg.file(".*bar.rb").ast.isCall.where(_.referencedImports).tag._toEvaluatedImport.toList: @unchecked - bar.fullName shouldBe "foo.rb::program.FooModule" - } - - "resolve 'x' and 'y' locally under foo.rb" in { - val Some(x) = cpg.identifier("x").where(_.file.name(".*foo.*")).headOption: @unchecked - x.typeFullName shouldBe "__builtin.Integer" - val Some(y) = cpg.identifier("y").where(_.file.name(".*foo.*")).headOption: @unchecked - y.typeFullName shouldBe "__builtin.String" - } - - "resolve 'FooModule.x' and 'FooModule.y' field access primitive types correctly" ignore { - val List(z1, z2) = cpg.file - .name(".*bar.*") - .ast - .isIdentifier - .name("z") - .l - z1.typeFullName shouldBe "ANY" - z1.dynamicTypeHintFullName shouldBe Seq("__builtin.Integer", "__builtin.String") - z2.typeFullName shouldBe "ANY" - z2.dynamicTypeHintFullName shouldBe Seq("__builtin.Integer", "__builtin.String") - } - - "resolve 'FooModule.d' field access object types correctly" ignore { - val Some(d) = cpg.file - .name(".*bar.*") - .ast - .isIdentifier - .name("d") - .headOption: @unchecked - d.typeFullName shouldBe "dbi::program.DBI.connect." - d.dynamicTypeHintFullName shouldBe Seq() - } - - "resolve a 'select_one' call indirectly from 'FooModule.d' field access correctly" ignore { - val List(d) = cpg.file - .name(".*bar.*") - .ast - .isCall - .name("select_one") - .l - d.methodFullName shouldBe "dbi::program.DBI.connect..select_one" - d.dynamicTypeHintFullName shouldBe Seq() - d.callee(NoResolve).isExternal.headOption shouldBe Some(true) - } - - } - - "assignment from a call to a identifier inside an imported module using new" should { - lazy val cpg = code(""" - |require 'logger' - | - |log = Logger.new(STDOUT) - |log.error("foo") - | - |""".stripMargin).cpg - - "resolve correct imports via tag nodes" in { - val List(logging: ResolvedMethod, _) = - cpg.call.where(_.referencedImports).tag._toEvaluatedImport.toList: @unchecked - logging.fullName shouldBe s"logger::program.Logger.${XDefines.ConstructorMethodName}" - } - - "provide a dummy type" ignore { - val List(error) = cpg.call("error").l: @unchecked - error.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH - val Some(log) = cpg.identifier("log").headOption: @unchecked - log.typeFullName shouldBe "logger::program.Logger" - val List(errorCall) = cpg.call("error").l - errorCall.methodFullName shouldBe "logger::program.Logger.error" - } - } - - "assignment from a call to a identifier inside an imported module using methodCall" should { - lazy val cpg = code(""" - |require 'stripe' - | - |customer = Stripe::Customer.create - | - |""".stripMargin).cpg - - "resolved the type of call" in { - val Some(create) = cpg.call("create").headOption: @unchecked - create.methodFullName shouldBe "stripe::program.Stripe.Customer.create" - } - - "resolved the type of identifier" in { - val Some(customer) = cpg.identifier("customer").headOption: @unchecked - customer.typeFullName shouldBe "stripe::program.Stripe.Customer.create." - } - } - - "recovery of type for call having a method with same name" should { - lazy val cpg = code(""" - |require "dbi" - | - |def connect - | puts "I am here" - |end - | - |d = DBI.connect("DBI:Mysql:TESTDB:localhost", "testuser", "test123") - |""".stripMargin) - - "have a correct type for call `connect`" in { - cpg.call("connect").methodFullName.l shouldBe List("dbi::program.DBI.connect") - } - - "have a correct type for identifier `d`" in { - cpg.identifier("d").typeFullName.l shouldBe List("dbi::program.DBI.connect.") - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/UnknownConstructPass.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/UnknownConstructPass.scala deleted file mode 100644 index 84a233ed8b2b..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/UnknownConstructPass.scala +++ /dev/null @@ -1,89 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.joern.x2cpg.utils.Environment -import io.joern.x2cpg.utils.Environment.OperatingSystemType -import io.shiftleft.codepropertygraph.generated.nodes.Method -import io.shiftleft.semanticcpg.language.* - -class UnknownConstructPass extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "invalid assignment" ignore { - val cpg = code(""" - |a = 1 - |b = [[] - |c = 2 - |""".stripMargin) - - "be ignored" in { - val List(a, c) = cpg.assignment.l - a.target.code shouldBe "a" - a.source.code shouldBe "1" - - c.target.code shouldBe "c" - c.source.code shouldBe "2" - } - } - - "invalid method body" ignore { - val cpg = code(""" - |x = 1 - |def random(a) - | b = 3 + 2) - | y = 2 - |end - |z = 2 - |""".stripMargin) - - "preserve code around it and show the rest of the method body" in { - val List(x, y, z, random) = cpg.assignment.l - x.target.code shouldBe "x" - x.source.code shouldBe "1" - - y.target.code shouldBe "y" - y.source.code shouldBe "2" - - z.target.code shouldBe "z" - z.source.code shouldBe "2" - - random.target.code shouldBe "random" - random.source.code shouldBe "def random(...)" - - val List(m: Method) = cpg.method.nameExact("random").l - val List(_y) = m.assignment.l - y.id() shouldBe _y.id() - } - } - - "unrecognized token in the RHS of an assignment" ignore { - val cpg = code(""" - |x = \! - |y = 1 - |""".stripMargin) - - "be ignored" in { - val List(y) = cpg.assignment.l - - y.target.code shouldBe "y" - y.source.code shouldBe "1" - } - } - - "an attempted fix" ignore { - val cpg = code(""" - |class DerivedClass < BaseClass - | KEYS = %w( - | id1 - | id2 - | id3 - | ).freeze - |end - |""".stripMargin) - - "not cause an infinite loop once the last line is blanked out, at the cost of the structure (in Unix)" in { - cpg.typeDecl("DerivedClass").size shouldBe - (if (Environment.operatingSystem == OperatingSystemType.Windows) 1 else 0) - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/AssignCpgTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/AssignCpgTests.scala deleted file mode 100644 index 95d098a62ae8..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/AssignCpgTests.scala +++ /dev/null @@ -1,193 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.{DifferentInNewFrontend, RubyCode2CpgFixture, SameInNewFrontend} -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, nodes} -import io.shiftleft.semanticcpg.language.* -import org.scalatest.Tag - -class AssignCpgTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "single target assign" should { - val cpg = code("""x = 2""".stripMargin) - - "test local and identifier nodes" taggedAs SameInNewFrontend in { - val localX = cpg.local.head - localX.name shouldBe "x" - val List(idX) = localX.referencingIdentifiers.l: @unchecked - idX.name shouldBe "x" - } - - "test assignment node properties" taggedAs SameInNewFrontend in { - val assignCall = cpg.call.methodFullName(Operators.assignment).head - assignCall.code shouldBe "x = 2" - assignCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - assignCall.lineNumber shouldBe Some(1) - assignCall.columnNumber shouldBe Some(2) - } - - "test assignment node ast children" taggedAs SameInNewFrontend in { - cpg.call - .methodFullName(Operators.assignment) - .astChildren - .order(1) - .isIdentifier - .head - .code shouldBe "x" - cpg.call - .methodFullName(Operators.assignment) - .astChildren - .order(2) - .isLiteral - .head - .code shouldBe "2" - } - - "test assignment node arguments" taggedAs SameInNewFrontend in { - cpg.call - .methodFullName(Operators.assignment) - .argument - .argumentIndex(1) - .isIdentifier - .head - .code shouldBe "x" - cpg.call - .methodFullName(Operators.assignment) - .argument - .argumentIndex(2) - .isLiteral - .head - .code shouldBe "2" - } - } - - "nested decomposing assign" should { - val cpg = code("""x, (y, z) = [1, [2, 3]]""".stripMargin) - - def getSurroundingBlock: nodes.Block = { - cpg.all.collect { case block: nodes.Block if block.code != "" => block }.head - } - - "test block exists" in { - // Throws if block does not exist. - getSurroundingBlock - } - - // TODO: .code property need to be fixed - "test block node properties" ignore { - val block = getSurroundingBlock - block.code shouldBe - """tmp0 = list - |x = tmp0[0] - |y = tmp0[1][0] - |z = tmp0[1][1]""".stripMargin - block.lineNumber shouldBe Some(1) - } - - // TODO: Need to fix the local variables - "test local node" ignore { - cpg.method.name("Test0.rb::program").local.name("tmp0").headOption should not be empty - } - - "test tmp variable assignment" in { - val block = getSurroundingBlock - val tmpAssignNode = block.astChildren.isCall.sortBy(_.order).head - // tmpAssignNode.code shouldBe "tmp0 = list" - tmpAssignNode.methodFullName shouldBe Operators.assignment - tmpAssignNode.lineNumber shouldBe Some(1) - } - - // TODO: Fix the code property of the Block node & the order too - "test assignments to targets" ignore { - val block = getSurroundingBlock - val assignNodes = block.astChildren.isCall.sortBy(_.order).tail - assignNodes.map(_.code) should contain theSameElementsInOrderAs List( - "x = tmp0[0]", - "y = tmp0[1][0]", - "z = tmp0[1][1]" - ) - assignNodes.map(_.lineNumber.get) should contain theSameElementsInOrderAs List(1, 1, 1) - } - - } - - "array destructuring assign" should { - val cpg = code("""x, *, y = [1, 2, 3, 5]""".stripMargin) - - def getSurroundingBlock: nodes.Block = { - cpg.all.collect { case block: nodes.Block if block.code != "" => block }.head - } - - "test block exists" in { - // Throws if block does not exist. - getSurroundingBlock - } - - // TODO: .code property need to be fixed - "test block node properties" ignore { - val block = getSurroundingBlock - block.code shouldBe - """tmp0 = list - |x = tmp0[0] - |y = tmp0[1][0] - |z = tmp0[1][1]""".stripMargin - block.astChildren.length shouldBe 4 - cpg.identifier("x").isEmpty shouldBe false - cpg.identifier("y").isEmpty shouldBe false - cpg.identifier("z").isEmpty shouldBe false - - } - - // TODO: Need to fix the local variables - "test local node" ignore { - cpg.method.name("Test0.rb::program").local.name("tmp0").headOption should not be empty - } - - } - - "multi target assign" should { - val cpg = code("""x = y = "abcd"""".stripMargin) - - def getSurroundingBlock: nodes.Block = { - cpg.all.collect { case block: nodes.Block if block.code != "" => block }.head - } - - "test block exists" taggedAs DifferentInNewFrontend in { - // Throws if block does not exist. - getSurroundingBlock - } - - // TODO: Fix the code property of the Block node - "test block node properties" taggedAs DifferentInNewFrontend ignore { - val block = getSurroundingBlock - block.code shouldBe - """tmp0 = list - |x = tmp0 - |y = tmp0""".stripMargin - block.lineNumber shouldBe Some(1) - } - - // TODO: Need to fix the local variables - "test local node" taggedAs DifferentInNewFrontend ignore { - cpg.method.name("Test0.rb::program").local.name("tmp0").headOption should not be empty - } - - // TODO: Need to fix the code property - "test tmp variable assignment" taggedAs DifferentInNewFrontend ignore { - val block = getSurroundingBlock - val tmpAssignNode = block.astChildren.isCall.sortBy(_.order).head - tmpAssignNode.code shouldBe "tmp0 = list" - tmpAssignNode.methodFullName shouldBe Operators.assignment - tmpAssignNode.lineNumber shouldBe Some(1) - } - } - - "empty array assignment" should { - val cpg = code("""x.y = []""".stripMargin) - - "have an empty assignment" taggedAs DifferentInNewFrontend in { - val List(assignment) = cpg.call.name(Operators.assignment).l - assignment.argument.where(_.argumentIndex(2)).isCall.name.l shouldBe List(Operators.arrayInitializer) - assignment.argument.where(_.argumentIndex(2)).isCall.argument.l shouldBe List() - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/AttributeCpgTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/AttributeCpgTests.scala deleted file mode 100644 index 847ce13f137e..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/AttributeCpgTests.scala +++ /dev/null @@ -1,53 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.{RubyCode2CpgFixture, SameInNewFrontend} -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language.* - -class AttributeCpgTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - val cpg = code("""x.y""".stripMargin) - - // TODO: Class Modeling testcase - "test field access call node properties" taggedAs SameInNewFrontend ignore { - val callNode = cpg.call.methodFullName(Operators.fieldAccess).head - callNode.code shouldBe "x.y" - callNode.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - callNode.lineNumber shouldBe Some(1) - } - - // TODO: Class Modeling testcase - "test field access call ast children" ignore { - cpg.call - .methodFullName(Operators.fieldAccess) - .astChildren - .order(1) - .isIdentifier - .head - .code shouldBe "x" - cpg.call - .methodFullName(Operators.fieldAccess) - .astChildren - .order(2) - .isFieldIdentifier - .head - .code shouldBe "y" - } - - // TODO: Class Modeling testcase - "test field access call arguments" ignore { - cpg.call - .methodFullName(Operators.fieldAccess) - .argument - .argumentIndex(1) - .isIdentifier - .head - .code shouldBe "x" - cpg.call - .methodFullName(Operators.fieldAccess) - .argument - .argumentIndex(2) - .isFieldIdentifier - .head - .code shouldBe "y" - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/BinOpCpgTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/BinOpCpgTests.scala deleted file mode 100644 index 6b28930d82c2..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/BinOpCpgTests.scala +++ /dev/null @@ -1,53 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, NodeTypes, DispatchTypes, Operators, nodes} -import io.shiftleft.semanticcpg.language.* -import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal - -class BinOpCpgTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - val cpg = code("""1 + 2""".stripMargin) - - "test binOp 'add' call node properties" in { - val additionCall = cpg.call.methodFullName(Operators.addition).head - additionCall.code shouldBe "1 + 2" - additionCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - additionCall.lineNumber shouldBe Some(1) - // TODO additionCall.columnNumber shouldBe Some(1) - } - - "test binOp 'add' ast children" in { - cpg.call - .methodFullName(Operators.addition) - .astChildren - .order(1) - .isLiteral - .head - .code shouldBe "1" - cpg.call - .methodFullName(Operators.addition) - .astChildren - .order(2) - .isLiteral - .head - .code shouldBe "2" - } - - "test binOp 'add' arguments" in { - cpg.call - .methodFullName(Operators.addition) - .argument - .argumentIndex(1) - .isLiteral - .head - .code shouldBe "1" - cpg.call - .methodFullName(Operators.addition) - .argument - .argumentIndex(2) - .isLiteral - .head - .code shouldBe "2" - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/BoolOpCpgTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/BoolOpCpgTests.scala deleted file mode 100644 index e3e5690a2aed..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/BoolOpCpgTests.scala +++ /dev/null @@ -1,68 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.{DifferentInNewFrontend, RubyCode2CpgFixture, SameInNewFrontend} -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language.* - -class BoolOpCpgTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - val cpg = code("""x or y or z""".stripMargin) - - "test boolOp 'or' call node properties" taggedAs SameInNewFrontend in { - val orCall = cpg.call.head -// val orCall = cpg.call.methodFullName(Operators.logicalOr).head - orCall.code shouldBe "x or y or z" - orCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - orCall.lineNumber shouldBe Some(1) - // TODO orCall.columnNumber shouldBe Some(3) - } - - // TODO: Fix this multi logicalOr operation - "test boolOp 'or' ast children" taggedAs DifferentInNewFrontend ignore { - cpg.call - .methodFullName(Operators.logicalOr) - .astChildren - .order(1) - .isIdentifier - .head - .code shouldBe "x" - cpg.call - .methodFullName(Operators.logicalOr) - .astChildren - .order(2) - .isIdentifier - .head - .code shouldBe "y" - cpg.call - .methodFullName(Operators.logicalOr) - .astChildren - .order(3) - .isIdentifier - .head - .code shouldBe "z" - } - - // TODO: Fix this multi logicalOr operation arguments - "test boolOp 'or' arguments" taggedAs DifferentInNewFrontend ignore { - cpg.call - .methodFullName(Operators.logicalOr) - .argument - .argumentIndex(1) - .isIdentifier - .head - .code shouldBe "x" - cpg.call - .methodFullName(Operators.logicalOr) - .argument - .argumentIndex(2) - .isIdentifier - .head - .code shouldBe "y" - cpg.call - .methodFullName(Operators.logicalOr) - .argument - .argumentIndex(3) - .isIdentifier - .head - .code shouldBe "z" - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/CallCpgTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/CallCpgTests.scala deleted file mode 100644 index b2c499a6cd93..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/CallCpgTests.scala +++ /dev/null @@ -1,279 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.rubysrc2cpg.testfixtures.{RubyCode2CpgFixture, SameInNewFrontend} -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, MethodRef} -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, nodes} -import io.shiftleft.semanticcpg.language.* - -class CallCpgTests extends RubyCode2CpgFixture(withPostProcessing = true, useDeprecatedFrontend = true) { - "simple call method" should { - val cpg = code("""foo("a", b)""".stripMargin) - - "test call node properties" taggedAs SameInNewFrontend in { - val callNode = cpg.call.name("foo").head - callNode.code shouldBe """foo("a", b)""" - callNode.signature shouldBe "" - callNode.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - callNode.lineNumber shouldBe Some(1) - } - - "test call arguments" taggedAs SameInNewFrontend in { - val callNode = cpg.call.name("foo").head - val arg1 = callNode.argument(1) - arg1.code shouldBe "\"a\"" - - val arg2 = callNode.argument(2) - arg2.code shouldBe "b" - } - - "test astChildren" taggedAs SameInNewFrontend in { - val callNode = cpg.call.name("foo").head - val children = callNode.astChildren - children.size shouldBe 2 - - val firstChild = children.head - val secondChild = children.last - - firstChild.code shouldBe "\"a\"" - secondChild.code shouldBe "b" - } - } - - "call on identifier with named argument" should { - val cpg = code("""x.foo("a", b)""".stripMargin) - - "test call node properties" in { - val callNode = cpg.call.name("foo").head - callNode.code shouldBe """x.foo("a", b)""" - callNode.signature shouldBe "" - callNode.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - callNode.lineNumber shouldBe Some(1) - } - - "test call arguments" in { - val callNode = cpg.call.name("foo").head - val arg1 = callNode.argument(1) - arg1.code shouldBe "\"a\"" - - val arg2 = callNode.argument(2) - arg2.code shouldBe "b" - } - - "test astChildren" in { - val callNode = cpg.call.name("foo").head - val children = callNode.astChildren - children.size shouldBe 3 - - val firstChild = children.head - val lastChild = children.last - - firstChild.code shouldBe "x" - lastChild.code shouldBe "b" - } - } - - "call following a definition within the same module" should { - val cpg = code(""" - |def func(a, b) - | return a + b - |end - |x = func(a, b) - |""".stripMargin) - - "test call node properties" in { - val callNode = cpg.call.name("func").head - callNode.name shouldBe "func" - callNode.signature shouldBe "" - callNode.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - callNode.lineNumber shouldBe Some(5) - } - } - - "call with the splat operator" should { - val cpg = code(""" - |def print_list_of(**books_and_articles) - | books_and_articles.each do |book, article| - | puts book - | puts article - | end - |end - |# As an argument, we define a hash in which we will write books and articles. - |books_and_articles_we_love = { - | "Ruby on Rails 4": "What is webpack?", - | "Ruby essentials": "What is Ruby Object Model?", - | "Javascript essentials": "What is Object?" - |} - |print_list_of(books_and_articles_we_love) - |""".stripMargin) - - "test call node properties with children & argument" in { - val callNode = cpg.call.name("print_list_of").head - callNode.code shouldBe "print_list_of(books_and_articles_we_love)" - callNode.signature shouldBe "" - callNode.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - callNode.lineNumber shouldBe Some(14) - callNode.astChildren.last.code shouldBe "books_and_articles_we_love" - callNode.argument.last.code shouldBe "books_and_articles_we_love" - } - } - - "call with a heredoc parameter" should { - val cpg = code("""foo(<<~SQL) - |SELECT * FROM food - |WHERE healthy = true - |SQL - |""".stripMargin) - - "take note of the here doc location and construct a literal from the following statements" in { - val List(sql) = cpg.call.nameExact("foo").argument.isLiteral.l: @unchecked - sql.code shouldBe - """SELECT * FROM food - |WHERE healthy = true - |""".stripMargin.trim - sql.lineNumber shouldBe Option(1) - sql.columnNumber shouldBe Option(4) - sql.typeFullName shouldBe Defines.String - } - } - - // TODO: Handle multiple heredoc parameters - "call with multiple heredoc parameters" ignore { - val cpg = code("""puts(<<-ONE, <<-TWO) - |content for heredoc one - |ONE - |content for heredoc two - |TWO - |""".stripMargin) - - "take note of the here doc locations and construct the literals respectively from the following statements" in { - val List(one, two) = cpg.call.nameExact("puts").argument.isLiteral.l: @unchecked - one.code shouldBe "content for heredoc one" - one.lineNumber shouldBe Option(1) - one.columnNumber shouldBe Option(5) - one.typeFullName shouldBe Defines.String - two.code shouldBe "content for heredoc two" - two.lineNumber shouldBe Option(1) - two.columnNumber shouldBe Option(13) - two.typeFullName shouldBe Defines.String - } - } - - "a call with a normal and a do block argument" should { - val cpg = code(""" - |def client - | Faraday.new(API_HOST) do |builder| - | builder.request :json - | builder.options[:timeout] = READ_TIMEOUT - | builder.options[:open_timeout] = OPEN_TIMEOUT - | end - |end - |""".stripMargin) - - "have the correct arguments in the correct ordering" in { - val List(n) = cpg.call.nameExact("new").l: @unchecked - val List(faraday: Identifier, apiHost: Identifier, doRef: MethodRef) = n.argument.l: @unchecked - faraday.name shouldBe "Faraday" - faraday.argumentIndex shouldBe 0 - apiHost.name shouldBe "API_HOST" - apiHost.argumentIndex shouldBe 1 - doRef.methodFullName shouldBe "Test0.rb::program.new3" - doRef.argumentIndex shouldBe 2 - } - } - - "a call without parenthesis before the method definition is seen/resolved" should { - val cpg = code( - """ - |require "foo.rb" - | - |def event_params - | @event_params ||= device_params - | .merge(params) - | .merge(encoded_partner_params) - | .merge( - | s2s: 1, - | created_at_unix: Time.current.to_i, - | app_token: app_token, - | event_token: event_token, - | install_source: install_source - | ) - |end - |""".stripMargin, - "bar.rb" - ) - .moreCode( - """ - |def device_params - | case platform - | when :android - | { adid: adid, gps_adid: gps_adid } - | when :ios - | { adid: adid, idfa: idfa } - | else - | {} - | end - |end - |""".stripMargin, - "foo.rb" - ) - - "have its call node correctly identified and created" in { - val List(deviceParams) = cpg.call.nameExact("device_params").l: @unchecked - deviceParams.name shouldBe "device_params" - deviceParams.code shouldBe "device_params" - deviceParams.methodFullName shouldBe "foo.rb::program.device_params" - deviceParams.typeFullName shouldBe Defines.Any - deviceParams.lineNumber shouldBe Option(5) - deviceParams.columnNumber shouldBe Option(22) - deviceParams.argumentIndex shouldBe 0 - } - } - - "a parenthesis-less call (defined later in the module) in a call's argument" should { - val cpg = code(""" - |module Pay - | module Webhooks - | class BraintreeController < Pay::ApplicationController - | if Rails.application.config.action_controller.default_protect_from_forgery - | skip_before_action :verify_authenticity_token - | end - | - | def create - | queue_event(verified_event) # <------ verified event is a call here - | head :ok - | rescue ::Braintree::InvalidSignature - | head :bad_request - | end - | - | private - | - | def queue_event(event) - | return unless Pay::Webhooks.delegator.listening?("braintree.#{event.kind}") - | - | record = Pay::Webhook.create!( - | processor: :braintree, - | event_type: event.kind, - | event: {bt_signature: params[:bt_signature], bt_payload: params[:bt_payload]} - | ) - | Pay::Webhooks::ProcessJob.perform_later(record) - | end - | - | def verified_event - | Pay.braintree_gateway.webhook_notification.parse(params[:bt_signature], params[:bt_payload]) - | end - | end - | end - |end - |""".stripMargin) - - "be a call node instead of an identifier" in { - inside(cpg.call("queue_event").argument.l) { - case (verifiedEvent: Call) :: Nil => - verifiedEvent.name shouldBe "verified_event" - case xs => - fail(s"Expected a single call argument, received [${xs.map(x => x.label -> x.code).mkString(", ")}] instead!") - } - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/CustomAssignmentTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/CustomAssignmentTests.scala deleted file mode 100644 index 429a39c47961..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/CustomAssignmentTests.scala +++ /dev/null @@ -1,57 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.{DifferentInNewFrontend, RubyCode2CpgFixture} -import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, MethodRef, TypeRef} -import io.shiftleft.semanticcpg.language.* - -class CustomAssignmentTests extends RubyCode2CpgFixture(withPostProcessing = true, useDeprecatedFrontend = true) { - - "custom assignment for builtIn" should { - val cpg = code(""" - |puts "This is ruby" - |""".stripMargin) - "be created for builtin presence" taggedAs DifferentInNewFrontend in { - val List(putsAssignmentCall, _) = cpg.call.l - putsAssignmentCall.name shouldBe ".assignment" - - val List(putsIdentifier: Identifier, putsBuiltInTypeRef: TypeRef) = putsAssignmentCall.argument.l: @unchecked - - putsIdentifier.name shouldBe "puts" - putsBuiltInTypeRef.code shouldBe "__builtin.puts" - putsBuiltInTypeRef.typeFullName shouldBe "__builtin.puts" - } - - "resolve type for `puts`" in { - val List(_, putsCall) = cpg.call.l - putsCall.name shouldBe "puts" - putsCall.methodFullName shouldBe "__builtin.puts" - } - } - - "custom assignment for user defined function" should { - val cpg = code(""" - |def foo() - | return "This is my foo" - |end - | - |foo() - |""".stripMargin) - "be created" in { - val List(fooAssignmentCall, _) = cpg.call.l - fooAssignmentCall.name shouldBe ".assignment" - - val List(fooIdentifier: Identifier, fooMethodRef: MethodRef) = fooAssignmentCall.argument.l: @unchecked - - fooIdentifier.name shouldBe "foo" - fooMethodRef.methodFullName shouldBe "Test0.rb::program.foo" - fooMethodRef.referencedMethod.name shouldBe "foo" - } - - "resolve type for `foo`" in { - val List(_, fooCall) = cpg.call.l - fooCall.name shouldBe "foo" - fooCall.methodFullName shouldBe "Test0.rb::program.foo" - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/DoBlockTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/DoBlockTest.scala deleted file mode 100644 index 060d15d2d2a1..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/DoBlockTest.scala +++ /dev/null @@ -1,34 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language.* - -class DoBlockTest extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "defining a method using metaprogramming and a do-block function" should { - val cpg = code(s""" - |define_method foo do |name, age| - | value = public_send("#{name}_value") - | unit = public_send("#{name}_unit") - | - | puts "My name is #{name} and age is #{age}" - | - | next unless value.present? && unit.present? - | value.public_send(unit) - |end - |""".stripMargin) - - "create a do-block method called `foo`" in { - val nameMethod :: _ = cpg.method.nameExact("foo").l: @unchecked - - val List(name, age) = nameMethod.parameter.l - name.name shouldBe "name" - age.name shouldBe "age" - - val List(value, unit) = nameMethod.local.l - value.name shouldBe "value" - unit.name shouldBe "unit" - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/FileTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/FileTests.scala deleted file mode 100644 index d10d44769cdd..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/FileTests.scala +++ /dev/null @@ -1,62 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language._ -import io.shiftleft.semanticcpg.language.types.structure.FileTraversal -import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal - -class FileTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - val cpg = code(""" - |def foo() - |end - |def bar() - |end - |class MyClass - |end - |""".stripMargin) - - // TODO: Fix this unit test - "should contain two file nodes in total, both with order=0" ignore { - cpg.file.order.l shouldBe List(0, 0) - cpg.file.name(FileTraversal.UNKNOWN).size shouldBe 1 - cpg.file.nameNot(FileTraversal.UNKNOWN).size shouldBe 1 - } - - "should contain exactly one placeholder file node with `name=\"\"/order=0`" in { - cpg.file(FileTraversal.UNKNOWN).order.l shouldBe List(0) - cpg.file(FileTraversal.UNKNOWN).hash.l shouldBe List() - } - - "should allow traversing from file to its namespace blocks" in { - cpg.file.nameNot(FileTraversal.UNKNOWN).namespaceBlock.name.toSetMutable shouldBe Set( - NamespaceTraversal.globalNamespaceName - ) - } - - "should allow traversing from file to its methods via namespace block" in { - cpg.file.nameNot(FileTraversal.UNKNOWN).method.name.toSetMutable shouldBe Set("foo", "bar", "", ":program") - } - - // TODO: TypeDecl fix this unit test - "should allow traversing from file to its type declarations via namespace block" ignore { - cpg.file - .nameNot(FileTraversal.UNKNOWN) - .typeDecl - .nameNot(NamespaceTraversal.globalNamespaceName) - .name - .l - .sorted shouldBe List("MyClass") - } - - // TODO: Need to fix this test. - "should allow traversing to namespaces" ignore { - val List(ns1, ns2) = cpg.file.namespaceBlock.l - // At present it returning full file system path. It should return relative path - ns1.filename shouldBe "Test0.rb" - // At present it returning full file system path. It should return relative path - ns1.fullName shouldBe "Test0.rb:" - ns2.filename shouldBe "" - ns2.fullName shouldBe "" - cpg.file.namespace.name(NamespaceTraversal.globalNamespaceName).l.size shouldBe 2 - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/FormatStringCpgTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/FormatStringCpgTests.scala deleted file mode 100644 index 3b58d7b5ec54..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/FormatStringCpgTests.scala +++ /dev/null @@ -1,59 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, NodeTypes, DispatchTypes, Operators, nodes} -import io.shiftleft.semanticcpg.language.* -import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal - -class FormatStringCpgTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - "#string interpolation" should { - val cpg = code("""puts "pre#{x}post"""".stripMargin) - "test formatValue operator node" in { - val callNode = cpg.call.methodFullName(".formatValue").head - callNode.code shouldBe "#{x}" - callNode.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - callNode.lineNumber shouldBe Some(1) - } - - "test formatString operator node arguments" ignore { - val callNode = cpg.call.methodFullName(".formatValue").head - - val child1 = callNode.astChildren.order(1).isLiteral.head - child1.code shouldBe "pre" - child1.argumentIndex shouldBe 1 - - val child2 = callNode.astChildren.order(2).isCall.head - child2.code shouldBe "#{x}" - child2.argumentIndex shouldBe 2 - child2.methodFullName shouldBe ".formatValue" - child2.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - - val child3 = callNode.astChildren.order(3).isLiteral.head - child3.code shouldBe "post" - child3.argumentIndex shouldBe 3 - } - - "test formattedValue operator child" ignore { - val callNode = cpg.call.methodFullName(".formatValue").head - - val child1 = callNode.astChildren.order(1).isIdentifier.head - child1.code shouldBe "x" - child1.argumentIndex shouldBe 1 - } - } - - "test format string with multiple replacement fields" in { - val cpg = code("""puts "The number #{a} is less than #{b}"""".stripMargin) - val callNodeA = cpg.call.methodFullName(".formatValue").head - val callNodeB = cpg.call.methodFullName(".formatValue").last - callNodeA.code shouldBe "#{a}" - callNodeB.code shouldBe "#{b}" - } - - "test format string with only single replacement field" in { - val cpg = code("""puts "#{a}"""".stripMargin) - val callNode = cpg.call.methodFullName(".formatValue").head - callNode.code shouldBe "#{a}" - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/IdentifierLocalTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/IdentifierLocalTests.scala deleted file mode 100644 index af1d8a6f45e4..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/IdentifierLocalTests.scala +++ /dev/null @@ -1,85 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language._ - -class IdentifierLocalTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - val cpg = code(""" - |def method1() - | x = 1 - | x = 2 - |end - | - |def method2(x) - | x = 2 - |end - | - |def method3(x) - | y = 0 - | - | if true - | innerx = 0 - | innery = 0 - | - | innerx = 1 - | innery = 1 - | end - | - | x = 1 - | y = 1 - |end - | - |""".stripMargin) - - // TODO: Need to be fixed. - "be correct for local x in method1" ignore { - val List(method) = cpg.method.nameExact("method1").l - method.block.ast.isIdentifier.l.size shouldBe 2 - val List(identifierX, _) = method.block.ast.isIdentifier.l - identifierX.name shouldBe "x" - - val localX = identifierX._localViaRefOut.get - localX.name shouldBe "x" - } - - "be correct for parameter x in method2" in { - val List(method) = cpg.method.nameExact("method2").l - val List(identifierX) = method.block.ast.isIdentifier.l - identifierX.name shouldBe "x" - - identifierX.refsTo.l.size shouldBe 1 - val List(paramx) = identifierX.refsTo.l - paramx.name shouldBe "x" - - val parameterX = identifierX._methodParameterInViaRefOut.get - parameterX.name shouldBe "x" - } - - "Reach parameter from last identifier" in { - val List(method) = cpg.method.nameExact("method3").l - val List(outerIdentifierX) = method.ast.isIdentifier.lineNumber(22).l - val parameterX = outerIdentifierX._methodParameterInViaRefOut.get - parameterX.name shouldBe "x" - } - - // TODO: Need to be fixed. - "inner block test" ignore { - val List(method) = cpg.method.nameExact("method3").l - method.block.astChildren.isBlock.l.size shouldBe 1 - val List(nestedBlock) = method.block.astChildren.isBlock.l - nestedBlock.ast.isIdentifier.nameExact("innerx").l.size shouldBe 2 - } - - // TODO: Need to be fixed. - "nested block identifier to local traversal" ignore { - val List(method) = cpg.method.nameExact("method3").l - method.block.astChildren.isBlock.l.size shouldBe 1 - val List(nestedBlock) = method.block.astChildren.isBlock.l - nestedBlock.ast.isIdentifier.nameExact("innerx").l.size shouldBe 2 - val List(nestedIdentifierX, _) = nestedBlock.ast.isIdentifier.nameExact("innerx").l - - val nestedLocalX = nestedIdentifierX._localViaRefOut.get - nestedLocalX.name shouldBe "innerx" - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/ImportAstCreationTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/ImportAstCreationTest.scala deleted file mode 100644 index 3149bcdbf85d..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/ImportAstCreationTest.scala +++ /dev/null @@ -1,33 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language.* - -class ImportAstCreationTest extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "Ast creation for import node" should { - val cpg = code(""" - |require "dummy_logger" - |require_relative "util/help.rb" - |load "mymodule.rb" - |""".stripMargin) - val imports = cpg.imports.l - val calls = cpg.call("require|require_relative|load").l - "have a valid import node" in { - imports.importedEntity.l shouldBe List("dummy_logger", "util/help.rb", "mymodule.rb") - imports.importedAs.l shouldBe List("dummy_logger", "util/help.rb", "mymodule.rb") - } - - "have a valid call node" in { - calls.code.l shouldBe List( - "require \"dummy_logger\"", - "require_relative \"util/help.rb\"", - "load \"mymodule.rb\"" - ) - } - - "have a valid linking" in { - calls.referencedImports.importedEntity.l shouldBe List("dummy_logger", "util/help.rb", "mymodule.rb") - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/LiteralCpgTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/LiteralCpgTests.scala deleted file mode 100644 index 35b1b4a561c8..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/LiteralCpgTests.scala +++ /dev/null @@ -1,27 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language.* -class LiteralCpgTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "A here doc string literal" should { - val cpg = code("""<<-SQL - |SELECT * FROM food - |WHERE healthy = true - |SQL - |""".stripMargin) - - "be interpreted as a single literal string" in { - val List(sql) = cpg.literal.l: @unchecked - sql.code shouldBe - """SELECT * FROM food - |WHERE healthy = true - |""".stripMargin.trim - sql.lineNumber shouldBe Option(1) - sql.columnNumber shouldBe Option(0) - sql.typeFullName shouldBe Defines.String - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/MetaDataTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/MetaDataTests.scala deleted file mode 100644 index 4ff668b72a19..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/MetaDataTests.scala +++ /dev/null @@ -1,28 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.joern.x2cpg.layers.{Base, CallGraph, ControlFlow, TypeRelations} -import io.shiftleft.codepropertygraph.generated.Languages -import io.shiftleft.semanticcpg.language._ -class MetaDataTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - val cpg = code("""puts 123""") - - "should contain exactly one node with all mandatory fields set" in { - val List(x) = cpg.metaData.l - x.language shouldBe Languages.RUBYSRC - x.version shouldBe "0.1" - x.overlays shouldBe List( - Base.overlayName, - ControlFlow.overlayName, - TypeRelations.overlayName, - CallGraph.overlayName - ) - x.hash shouldBe None - } - - "should not have any incoming or outgoing edges" in { - cpg.metaData.size shouldBe 1 - cpg.metaData.in.l shouldBe List() - cpg.metaData.out.l shouldBe List() - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/MethodOneTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/MethodOneTests.scala deleted file mode 100644 index ad3c1d9d3fb9..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/MethodOneTests.scala +++ /dev/null @@ -1,185 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.{DifferentInNewFrontend, RubyCode2CpgFixture, SameInNewFrontend} -import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language.* - -class MethodOneTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "Method test with regular keyword def and end " should { - val cpg = code(""" - |def foo(a, b) - | return "" - |end - |""".stripMargin) - - "should contain exactly one method node with correct fields" in { - inside(cpg.method.name("foo").l) { case List(x) => - x.name shouldBe "foo" - x.isExternal shouldBe false - x.fullName shouldBe "Test0.rb::program.foo" - x.code should startWith("def foo(a, b)") - x.isExternal shouldBe false - x.order shouldBe 3 - x.filename.endsWith("Test0.rb") - x.lineNumber shouldBe Option(2) - x.lineNumberEnd shouldBe Option(4) - } - } - - "should return correct number of lines" taggedAs SameInNewFrontend in { - cpg.method.name("foo").numberOfLines.l shouldBe List(3) - } - - "should allow traversing to parameters" in { - cpg.method.name("foo").parameter.name.toSetMutable shouldBe Set("a", "b") - } - - "should allow traversing to methodReturn" ignore { - cpg.method.name("foo").methodReturn.l.size shouldBe 1 - cpg.method.name("foo").methodReturn.typeFullName.head shouldBe "String" - } - - "should allow traversing to method" taggedAs DifferentInNewFrontend in { - cpg.methodReturn.method.name.l shouldBe List("foo", ":program", ".assignment") - } - - "should allow traversing to file" in { - cpg.method.name("foo").file.name.l should not be empty - } - - // TODO: need to be fixed. - "test corresponding type, typeDecl and binding" ignore { - cpg.method.fullName("Test0.rb::program.foo").referencingBinding.bindingTypeDecl.l should not be empty - val bindingTypeDecl = - cpg.method.fullName("Test0.rb::program.foo").referencingBinding.bindingTypeDecl.head - - bindingTypeDecl.name shouldBe "foo" - bindingTypeDecl.fullName shouldBe "Test0.rb::program.foo" - bindingTypeDecl.referencingType.name.head shouldBe "foo" - bindingTypeDecl.referencingType.fullName.head shouldBe "Test0.rb::program.foo" - } - - "test method parameter nodes" in { - cpg.method.name("foo").parameter.name.l.size shouldBe 2 - val parameter1 = cpg.method.fullName("Test0.rb::program.foo").parameter.order(1).head - parameter1.name shouldBe "a" - parameter1.index shouldBe 1 - parameter1.typeFullName shouldBe "ANY" - - val parameter2 = cpg.method.fullName("Test0.rb::program.foo").parameter.order(2).head - parameter2.name shouldBe "b" - parameter2.index shouldBe 2 - parameter2.typeFullName shouldBe "ANY" - } - - "should allow traversing from parameter to method" in { - cpg.parameter.name("a").method.name.l shouldBe List("foo") - cpg.parameter.name("b").method.name.l shouldBe List("foo") - } - } - - "Method with variable arguments" should { - val cpg = code(""" - |def foo(*names) - | return "" - |end - |""".stripMargin) - - "Variable argument properties should be rightly set" in { - cpg.parameter.name("names").l.size shouldBe 1 - val param = cpg.parameter.name("names").l.head - param.isVariadic shouldBe true - } - } - - "Multiple Return tests" should { - val cpg = code(""" - |def foo(names) - | if names == "Alice" - | return 1 - | else - | return 2 - | end - |end - |""".stripMargin) - - "be correct for multiple returns" in { - cpg.method("foo").methodReturn.l.size shouldBe 1 - cpg.method("foo").ast.isReturn.l.size shouldBe 2 - inside(cpg.method("foo").methodReturn.l) { case List(fooReturn) => - fooReturn.typeFullName shouldBe "ANY" - } - val astReturns = cpg.method("foo").ast.isReturn.l - inside(astReturns) { case List(ret1, ret2) => - ret1.code shouldBe "return 1" - ret1.lineNumber shouldBe Option(4) - ret2.code shouldBe "return 2" - ret2.lineNumber shouldBe Option(6) - } - } - } - - "Function with empty array in block" should { - val cpg = code(""" - |def foo - | [] - |end - |""".stripMargin) - - "contain empty array" taggedAs SameInNewFrontend in { - cpg.method.name("foo").size shouldBe 1 - cpg.method.name("foo").block.containsCallTo(Operators.arrayInitializer).size shouldBe 1 - } - } - - "Function as a list element in accessor" should { - val cpg = code(""" - |class Bar - | attr_accessor :a, - | :b, - | def self.c - | 1 - | end - |end - |""".stripMargin) - - "contain empty array" taggedAs DifferentInNewFrontend in { - cpg.identifier("c").astParent.isCallTo("attr_accessor").size shouldBe 1 - } - } - - "Function for private_class_method" should { - val cpg = code(""" - |private_class_method def foo(a) - | b - |end - |""".stripMargin) - - "have function identifier as argument and function definition" in { - // one from the METHOD_REF node and one on line `def self.c` - cpg.identifier("foo").astParent.isCallTo("private_class_method").size shouldBe 1 - cpg.method.nameExact("foo").size shouldBe 1 - } - - "Function for multiple function prefixes" should { - val cpg = code(""" - |class Foo - | private attr_reader :bar - | - | def bar - | x - | end - |end - |""".stripMargin) - - "have function identifier as argument and function definition" taggedAs DifferentInNewFrontend ignore { - /* FIXME: We are capturing the prefixes but order in ast is private -> attr_reader -> LITERAL(bar) - * We should duplicate the bar node and set parent as both methods */ - cpg.identifier("bar").astParent.isCallTo("private").size shouldBe 1 - cpg.identifier("bar").astParent.isCallTo("attr_reader").size shouldBe 1 - cpg.method.nameExact("bar").size shouldBe 1 - } - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/MethodTwoTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/MethodTwoTests.scala deleted file mode 100644 index 4037cebde09f..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/MethodTwoTests.scala +++ /dev/null @@ -1,103 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, NodeTypes} -import io.shiftleft.semanticcpg.language.* -import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal - -class MethodTwoTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "Method test with define_method" should { - val cpg = code(""" - |define_method(:foo) do |a, b| - | return "" - |end - |""".stripMargin) - - // TODO: This test cases needs to be fixed. - "should contain exactly one method node with correct fields" ignore { - inside(cpg.method.name("foo").l) { case List(x) => - x.name shouldBe "foo" - x.isExternal shouldBe false - x.fullName shouldBe "Test0.rb::program:foo" - x.code should startWith("def foo(a, b)") - x.isExternal shouldBe false - x.order shouldBe 1 - x.filename.endsWith("Test0.rb") - x.lineNumber shouldBe Option(2) - x.lineNumberEnd shouldBe Option(4) - } - } - - // TODO: This test cases needs to be fixed. - "should return correct number of lines" ignore { - cpg.method.name("foo").numberOfLines.l shouldBe List(3) - } - - // TODO: This test cases needs to be fixed. - "should allow traversing to parameters" ignore { - cpg.method.name("foo").parameter.name.toSetMutable shouldBe Set("a", "b") - } - - // TODO: This test cases needs to be fixed. - "should allow traversing to methodReturn" ignore { - cpg.method.name("foo").methodReturn.l.size shouldBe 1 - cpg.method.name("foo").methodReturn.typeFullName.head shouldBe "ANY" - } - - // TODO: This test cases needs to be fixed. - "should allow traversing to method" ignore { - cpg.methodReturn.method.name.l shouldBe List("foo", ":program") - } - - // TODO: This test cases needs to be fixed. - "should allow traversing to file" ignore { - cpg.method.name("foo").file.name.l should not be empty - } - - // TODO: Need to be fixed - "test function method ref" ignore { - cpg.methodRef("foo").referencedMethod.fullName.l should not be empty - cpg.methodRef("foo").referencedMethod.fullName.head shouldBe - "Test0.rb::program:foo" - } - - // TODO: Need to be fixed. - "test existence of local variable in module function" ignore { - cpg.method.fullName("Test0.rb::program").local.name.l should contain("foo") - } - - // TODO: need to be fixed. - "test corresponding type, typeDecl and binding" ignore { - cpg.method.fullName("Test0.rb::program:foo").referencingBinding.bindingTypeDecl.l should not be empty - val bindingTypeDecl = - cpg.method.fullName("Test0.rb::program:foo").referencingBinding.bindingTypeDecl.head - - bindingTypeDecl.name shouldBe "foo" - bindingTypeDecl.fullName shouldBe "Test0.rb::program:foo" - bindingTypeDecl.referencingType.name.head shouldBe "foo" - bindingTypeDecl.referencingType.fullName.head shouldBe "Test0.rb::program:foo" - } - - // TODO: Need to be fixed - "test method parameter nodes" ignore { - - cpg.method.name("foo").parameter.name.l.size shouldBe 2 - val parameter1 = cpg.method.fullName("Test0.rb::program:foo").parameter.order(1).head - parameter1.name shouldBe "a" - parameter1.index shouldBe 1 - parameter1.typeFullName shouldBe "ANY" - - val parameter2 = cpg.method.fullName("Test0.rb::program:foo").parameter.order(2).head - parameter2.name shouldBe "b" - parameter2.index shouldBe 2 - parameter2.typeFullName shouldBe "ANY" - } - - // TODO: Need to be fixed - "should allow traversing from parameter to method" ignore { - cpg.parameter.name("a").method.name.l shouldBe List("foo") - cpg.parameter.name("b").method.name.l shouldBe List("foo") - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/ModuleTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/ModuleTests.scala deleted file mode 100644 index 9ac50dd70fc8..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/ModuleTests.scala +++ /dev/null @@ -1,210 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language.* -import io.shiftleft.semanticcpg.language.types.structure.FileTraversal -import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal -import io.joern.x2cpg.Defines as XDefines -import io.shiftleft.codepropertygraph.generated.Operators -class ModuleTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "Simple module checks" should { - val cpg = code(""" - |module MyNamespace - | MY_CONSTANT = 20 - |end - |""".stripMargin) - "Check namespace basic block structure" in { - cpg.namespaceBlock - .nameNot(NamespaceTraversal.globalNamespaceName) - .filenameNot(FileTraversal.UNKNOWN) - .l - .size shouldBe 1 - val List(x) = cpg.namespaceBlock - .nameNot(NamespaceTraversal.globalNamespaceName) - .filenameNot(FileTraversal.UNKNOWN) - .l - x.name shouldBe "MyNamespace" - x.fullName shouldBe "Test0.rb::program.MyNamespace" - } - - "Respective dummy Method in place" in { - cpg.method(XDefines.StaticInitMethodName).l.size shouldBe 1 - val List(x) = cpg.method(XDefines.StaticInitMethodName).l - x.fullName shouldBe s"Test0.rb::program.MyNamespace.${XDefines.StaticInitMethodName}" - } - - "Respective dummy TypeDecl in place" in { - cpg.typeDecl("MyNamespace").l.size shouldBe 1 - val List(x) = cpg.typeDecl("MyNamespace").l - x.fullName shouldBe s"Test0.rb::program.MyNamespace" - } - } - - "Hierarchical module checks" should { - val cpg = code(""" - |module MyNamespaceParent - | module MyNamespaceChild - | SOME_CONSTATN = 10 - | end - |end - |""".stripMargin) - "Check namespace basic block structure" in { - cpg.namespaceBlock - .nameNot(NamespaceTraversal.globalNamespaceName) - .filenameNot(FileTraversal.UNKNOWN) - .l - .size shouldBe 2 - val List(x, x1) = cpg.namespaceBlock - .nameNot(NamespaceTraversal.globalNamespaceName) - .filenameNot(FileTraversal.UNKNOWN) - .l - x.name shouldBe "MyNamespaceParent" - x.fullName shouldBe s"Test0.rb::program.MyNamespaceParent" - - x1.name shouldBe "MyNamespaceChild" - x1.fullName shouldBe s"Test0.rb::program.MyNamespaceParent.MyNamespaceChild" - } - - "Respective dummy Method in place" in { - cpg.method(XDefines.StaticInitMethodName).l.size shouldBe 2 - cpg.method(XDefines.StaticInitMethodName).fullName.l shouldBe List( - s"Test0.rb::program.MyNamespaceParent.${XDefines.StaticInitMethodName}", - s"Test0.rb::program.MyNamespaceParent.MyNamespaceChild.${XDefines.StaticInitMethodName}" - ) - } - - "Respective dummy TypeDecl in place" in { - cpg.typeDecl("MyNamespaceChild").l.size shouldBe 1 - val List(x) = cpg.typeDecl("MyNamespaceChild").l - x.fullName shouldBe s"Test0.rb::program.MyNamespaceParent.MyNamespaceChild" - } - } - - "Module Internal structure checks with member variable" should { - val cpg = code(""" - |module MyNamespace - | @@plays = 0 - | class MyClass - | def method1 - | puts "Method 1" - | end - | end - |end - |""".stripMargin) - "Class structure in plcae" in { - cpg.typeDecl("MyClass").l.size shouldBe 1 - val List(x) = cpg.typeDecl("MyClass").l - x.fullName shouldBe s"Test0.rb::program.MyNamespace.MyClass" - } - - "Class Method structure in place" in { - cpg.method("method1").l.size shouldBe 1 - val List(x) = cpg.method("method1").l - x.fullName shouldBe s"Test0.rb::program.MyNamespace.MyClass.method1" - } - - "member variables structure in place" in { - val List(classInit) = cpg.method(XDefines.StaticInitMethodName).l - classInit.fullName shouldBe s"Test0.rb::program.MyNamespace.${XDefines.StaticInitMethodName}" - val List(playsDef) = classInit.call.nameExact(Operators.fieldAccess).fieldAccess.l - playsDef.fieldIdentifier.canonicalName.headOption shouldBe Option("plays") - - val List(myclassTd) = cpg.typeDecl("MyNamespace").l - val List(plays) = myclassTd.member.l - plays.name shouldBe "plays" - cpg.fieldAccess.fieldIdentifier.canonicalName.l shouldBe List("plays") - } - } - - "Module internal structure checks with Constant defined in module" should { - val cpg = code(""" - |module MyNamespace - | MY_CONSTANT = 0 - |end - |""".stripMargin) - // TODO Ignoring below test case and the function where this is implemented treats every UpperCase node as Constant which is incorrect and causing conflicts elsewhere - "member variables structure in place" ignore { - val List(moduleInit) = cpg.method(XDefines.StaticInitMethodName).l - moduleInit.fullName shouldBe s"Test0.rb::program.MyNamespace.${XDefines.StaticInitMethodName}" - val List(myconstant) = moduleInit.call.nameExact(Operators.fieldAccess).fieldAccess.l - myconstant.fieldIdentifier.canonicalName.headOption shouldBe Option("MY_CONSTANT") - - val List(myclassTd) = cpg.typeDecl("MyNamespace").l - val List(myConstant) = myclassTd.member.l - myConstant.name shouldBe "MY_CONSTANT" - cpg.fieldAccess.fieldIdentifier.canonicalName.l shouldBe List("MY_CONSTANT") - } - } - - "Hierarchical module checks with constants" ignore { - val cpg = code(""" - |module MyNamespace - | MY_CONSTANT = 0 - | @@plays = 0 - | module ChildModule - | @@name = 0 - | MY_CONSTANT = 0 - | end - |end - |""".stripMargin) - - "member variables structure in place" in { - val List(modInit1, modInit2) = cpg.method(XDefines.StaticInitMethodName).l - modInit1.fullName shouldBe s"Test0.rb::program.MyNamespace.${XDefines.StaticInitMethodName}" - val List(myconstantfa, playsfa) = modInit1.call.nameExact(Operators.fieldAccess).fieldAccess.l - myconstantfa.fieldIdentifier.canonicalName.headOption shouldBe Option("MY_CONSTANT") - playsfa.fieldIdentifier.canonicalName.headOption shouldBe Option("plays") - - modInit2.fullName shouldBe s"Test0.rb::program.MyNamespace.ChildModule.${XDefines.StaticInitMethodName}" - val List(namefa, myconstant2fa) = modInit2.call.nameExact(Operators.fieldAccess).fieldAccess.l - myconstant2fa.fieldIdentifier.canonicalName.headOption shouldBe Option("MY_CONSTANT") - namefa.fieldIdentifier.canonicalName.headOption shouldBe Option("name") - - val List(myclassTd2) = cpg.typeDecl("ChildModule").l - val List(namem, myConstant2m) = myclassTd2.member.l - myConstant2m.name shouldBe "MY_CONSTANT" - namem.name shouldBe "name" - cpg.fieldAccess.fieldIdentifier.canonicalName.l shouldBe List("name", "MY_CONSTANT") - - val List(myclassTd) = cpg.typeDecl("MyNamespace").l - val List(myconstantm, playsm) = myclassTd.member.l - myconstantm.name shouldBe "MY_CONSTANT" - playsm.name shouldBe "plays" - cpg.fieldAccess.fieldIdentifier.canonicalName.l shouldBe List("MY_CONSTANT", "plays") - - } - } - - "Class inside module checks with constants" ignore { - val cpg = code(""" - |module MyNamespace - | MY_CONSTANT = 0 - | class ChildCls - | MY_CONSTANT = 0 - | end - |end - |""".stripMargin) - - "member variables structure in place" in { - val List(modInit1, modInit2) = cpg.method(XDefines.StaticInitMethodName).l - modInit1.fullName shouldBe s"Test0.rb::program.MyNamespace.${XDefines.StaticInitMethodName}" - val List(myconstant) = modInit1.call.nameExact(Operators.fieldAccess).fieldAccess.l - myconstant.fieldIdentifier.canonicalName.headOption shouldBe Option("MY_CONSTANT") - - modInit2.fullName shouldBe s"Test0.rb::program.MyNamespace.ChildCls.${XDefines.StaticInitMethodName}" - val List(myconstant2) = modInit2.call.nameExact(Operators.fieldAccess).fieldAccess.l - myconstant2.fieldIdentifier.canonicalName.headOption shouldBe Option("MY_CONSTANT") - - val List(myclassTd) = cpg.typeDecl("MyNamespace").l - val List(myConstant) = myclassTd.member.l - myConstant.name shouldBe "MY_CONSTANT" - cpg.fieldAccess.fieldIdentifier.canonicalName.l shouldBe List("MY_CONSTANT") - - val List(myclassTd2) = cpg.typeDecl("ChildCls").l - val List(myConstant2) = myclassTd2.member.l - myConstant2.name shouldBe "MY_CONSTANT" - cpg.fieldAccess.fieldIdentifier.canonicalName.l shouldBe List("MY_CONSTANT") - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/NamespaceBlockTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/NamespaceBlockTest.scala deleted file mode 100644 index d79061ac44e5..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/NamespaceBlockTest.scala +++ /dev/null @@ -1,55 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language._ -import io.shiftleft.semanticcpg.language.types.structure.FileTraversal -import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal -import io.joern.x2cpg.Defines - -class NamespaceBlockTest extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - val cpg = code("""puts 123 - |def foo() - |end - |class MyClass - |end - |""".stripMargin) - - "should contain a correct global namespace block for the `` file" in { - val List(x) = cpg.namespaceBlock.filename(FileTraversal.UNKNOWN).l - x.name shouldBe NamespaceTraversal.globalNamespaceName - x.fullName shouldBe NamespaceTraversal.globalNamespaceName - x.order shouldBe 1 - } - - "should contain correct namespace block for known file" in { - val List(x) = cpg.namespaceBlock.filenameNot(FileTraversal.UNKNOWN).l - x.name shouldBe NamespaceTraversal.globalNamespaceName - x.filename should not be empty - x.fullName shouldBe s"${x.filename}:${NamespaceTraversal.globalNamespaceName}" - x.order shouldBe 1 - } - - "should allow traversing from namespace block to method" in { - cpg.namespaceBlock.filenameNot(FileTraversal.UNKNOWN).ast.isMethod.name.l shouldBe List( - ":program", - "foo", - Defines.ConstructorMethodName - ) - } - - "should allow traversing from namespace block to type declaration" in { - cpg.namespaceBlock - .filenameNot(FileTraversal.UNKNOWN) - .ast - .isTypeDecl - .nameNot(NamespaceTraversal.globalNamespaceName) - .name - .l shouldBe List("MyClass") - } - - "should allow traversing from namespace block to namespace" in { - cpg.namespaceBlock.filenameNot(FileTraversal.UNKNOWN).namespace.name.l shouldBe List( - NamespaceTraversal.globalNamespaceName - ) - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/RescueKeywordCpgTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/RescueKeywordCpgTests.scala deleted file mode 100644 index 950d9c752226..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/RescueKeywordCpgTests.scala +++ /dev/null @@ -1,50 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, NodeTypes, DispatchTypes, Operators, nodes} -import io.shiftleft.semanticcpg.language.* -import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal - -class RescueKeywordCpgTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - "rescue in the immediate scope of a `def` block" in { - val cpg = code("""def foo - |1/0 - |rescue ZeroDivisionError => e - |end""".stripMargin) - - val methodNode = cpg.method.name("foo").head - methodNode.name shouldBe "foo" - methodNode.numberOfLines shouldBe 4 - methodNode.astChildren.isBlock.astChildren.code.contains("try") shouldBe true - - val zeroDivisionErrorIdentifier = cpg.identifier("ZeroDivisionError").head - zeroDivisionErrorIdentifier.code shouldBe "ZeroDivisionError" - zeroDivisionErrorIdentifier.astSiblings.isIdentifier.head.name shouldBe "e" - zeroDivisionErrorIdentifier.astParent.isBlock shouldBe true - } - - "rescue in the immediate scope of a `do` block" ignore { - val cpg = code("""foo x do |y| - |y/0 - |rescue ZeroDivisionError => e - |end""".stripMargin) - - val zeroDivisionErrorIdentifier = cpg.identifier("ZeroDivisionError").head - zeroDivisionErrorIdentifier.code shouldBe "ZeroDivisionError" - zeroDivisionErrorIdentifier.astSiblings.isIdentifier.head.name shouldBe "e" - zeroDivisionErrorIdentifier.astParent.isBlock shouldBe true - } - - "rescue in the immediate scope of a `begin` block" in { - val cpg = code("""begin - |1/0 - |rescue ZeroDivisionError => e - |end""".stripMargin) - - val zeroDivisionErrorIdentifier = cpg.identifier("ZeroDivisionError").head - zeroDivisionErrorIdentifier.code shouldBe "ZeroDivisionError" - zeroDivisionErrorIdentifier.astSiblings.isIdentifier.head.name shouldBe "e" - zeroDivisionErrorIdentifier.astParent.isBlock shouldBe true - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/ReturnTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/ReturnTests.scala deleted file mode 100644 index 9a85b3d1a26f..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/ReturnTests.scala +++ /dev/null @@ -1,22 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.{DifferentInNewFrontend, RubyCode2CpgFixture} -import io.shiftleft.codepropertygraph.generated.nodes.MethodRef -import io.shiftleft.semanticcpg.language.* - -class ReturnTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "a method, where the last statement is a method" should { - val cpg = code(""" - |Row = Struct.new(:cancel_date) do - | def end_date = cancel_date - |end - |""".stripMargin) - - "return a method ref" taggedAs DifferentInNewFrontend in { - val List(mRef: MethodRef) = cpg.method("new2").ast.isReturn.astChildren.l: @unchecked - mRef.methodFullName shouldBe "Test0.rb::program.end_date" - } - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/SimpleAstCreationPassTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/SimpleAstCreationPassTest.scala deleted file mode 100644 index fba69fcc69bf..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/SimpleAstCreationPassTest.scala +++ /dev/null @@ -1,1486 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.deprecated.astcreation.AstCreator -import io.joern.rubysrc2cpg.deprecated.passes.Defines -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal, NewIdentifier} -import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Operators} -import io.shiftleft.semanticcpg.language.* - -class SimpleAstCreationPassTest extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "AST generation for simple fragments" should { - - "have correct structure for a single command call" in { - val cpg = code("""puts 123""") - - val List(assign, puts) = cpg.call.l - val List(arg) = puts.argument.isLiteral.l - - puts.code shouldBe "puts 123" - puts.lineNumber shouldBe Some(1) - - arg.code shouldBe "123" - arg.lineNumber shouldBe Some(1) - arg.columnNumber shouldBe Some(5) - - assign.name shouldBe ".assignment" // call node for builtin typeRef assignment - } - - "have correct structure for an unsigned, decimal integer literal" in { - val cpg = code("123") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Integer" - literal.code shouldBe "123" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a +integer, decimal literal" in { - val cpg = code("+1") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Integer" - literal.code shouldBe "+1" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a -integer, decimal literal" in { - val cpg = code("-1") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Integer" - literal.code shouldBe "-1" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for an unsigned, decimal float literal" in { - val cpg = code("3.14") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Float" - literal.code shouldBe "3.14" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a +float, decimal literal" in { - val cpg = code("+3.14") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Float" - literal.code shouldBe "+3.14" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a -float, decimal literal" in { - val cpg = code("-3.14") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Float" - literal.code shouldBe "-3.14" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for an unsigned, decimal float literal with unsigned exponent" in { - val cpg = code("3e10") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Float" - literal.code shouldBe "3e10" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for an unsigned, decimal float literal with -exponent" in { - val cpg = code("12e-10") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Float" - literal.code shouldBe "12e-10" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for an unsigned, binary integer literal" in { - val cpg = code("0b01") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Integer" - literal.code shouldBe "0b01" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a -integer, binary literal" in { - val cpg = code("-0b01") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Integer" - literal.code shouldBe "-0b01" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a +integer, binary literal" in { - val cpg = code("+0b01") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Integer" - literal.code shouldBe "+0b01" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for an unsigned, hexadecimal integer literal" in { - val cpg = code("0xabc") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Integer" - literal.code shouldBe "0xabc" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a -integer, hexadecimal literal" in { - val cpg = code("-0xa") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Integer" - literal.code shouldBe "-0xa" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a +integer, hexadecimal literal" in { - val cpg = code("+0xa") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.Integer" - literal.code shouldBe "+0xa" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for `nil` literal" in { - val cpg = code("puts nil") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe Defines.NilClass - literal.code shouldBe "nil" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(5) - } - - "have correct structure for `true` literal" in { - val cpg = code("puts true") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe Defines.TrueClass - literal.code shouldBe "true" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(5) - } - - "have correct structure for `false` literal" in { - val cpg = code("puts false") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe Defines.FalseClass - literal.code shouldBe "false" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(5) - } - - "have correct structure for `self` identifier" in { - val cpg = code("puts self") - val List(self, _) = cpg.identifier.l - self.typeFullName shouldBe Defines.Object - self.code shouldBe "self" - self.lineNumber shouldBe Some(1) - self.columnNumber shouldBe Some(5) - } - - "have correct structure for `__FILE__` identifier" in { - val cpg = code("puts __FILE__") - val List(file, _) = cpg.identifier.l - file.typeFullName shouldBe "__builtin.String" - file.code shouldBe "__FILE__" - file.lineNumber shouldBe Some(1) - file.columnNumber shouldBe Some(5) - } - - "have correct structure for `__LINE__` identifier" in { - val cpg = code("puts __LINE__") - val List(line, _) = cpg.identifier.l - line.typeFullName shouldBe "__builtin.Integer" - line.code shouldBe "__LINE__" - line.lineNumber shouldBe Some(1) - line.columnNumber shouldBe Some(5) - } - - "have correct structure for `__ENCODING__` identifier" in { - val cpg = code("puts __ENCODING__") - val List(encoding, _) = cpg.identifier.l - encoding.typeFullName shouldBe Defines.Encoding - encoding.code shouldBe "__ENCODING__" - encoding.lineNumber shouldBe Some(1) - encoding.columnNumber shouldBe Some(5) - } - - "have correct structure for a single-line double-quoted string literal" in { - val cpg = code("\"hello\"") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.String" - literal.code shouldBe "\"hello\"" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a single-line single-quoted string literal" in { - val cpg = code("'hello'") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.String" - literal.code shouldBe "'hello'" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a single-line quoted non-expanded string literal" in { - val cpg = code("%q(hello)") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.String" - literal.code shouldBe "%q(hello)" - literal.lineNumber shouldBe Some(1) - } - - "have correct structure for a multi-line quoted non-expanded string literal" in { - val cpg = code("""%q< - |xyz - |123 - |>""".stripMargin) - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe "__builtin.String" - literal.code shouldBe - """%q< - |xyz - |123 - |>""".stripMargin - literal.lineNumber shouldBe Some(1) - } - - "have correct structure for an identifier symbol literal" in { - val cpg = code(":someSymbolName") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe Defines.Symbol - literal.code shouldBe ":someSymbolName" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a single-quoted-string symbol literal" in { - val cpg = code(":'someSymbolName'") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe Defines.Symbol - literal.code shouldBe ":'someSymbolName'" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for an identifier symbol literal used in an `undef` statement" in { - val cpg = code("undef :symbolName") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe Defines.Symbol - literal.code shouldBe ":symbolName" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(6) - } - - "have correct structure for a single-line regular expression literal" in { - val cpg = code("/(eu|us)/") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe Defines.Regexp - literal.code shouldBe "/(eu|us)/" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(0) - } - - "have correct structure for a single-line quoted (%r) regular expression literal" in { - val cpg = code("%r{eu|us}") - val List(literalNode) = cpg.literal.l - literalNode.typeFullName shouldBe Defines.Regexp - literalNode.code shouldBe "%r{eu|us}" - literalNode.lineNumber shouldBe Some(1) - literalNode.columnNumber shouldBe Some(0) - } - - "have correct structure for an empty regular expression literal used as the second argument to a call" in { - val cpg = code("puts(x, //)") - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe Defines.Regexp - literal.code shouldBe "//" - literal.lineNumber shouldBe Some(1) - literal.columnNumber shouldBe Some(8) - } - - "have correct structure for a single-line regular expression literal passed as argument to a command" in { - val cpg = code("puts /x/") - - val List(_, callNode) = cpg.call.l - callNode.code shouldBe "puts /x/" - callNode.name shouldBe "puts" - callNode.lineNumber shouldBe Some(1) - - val List(literalArg) = callNode.argument.isLiteral.l - literalArg.argumentIndex shouldBe 1 - literalArg.typeFullName shouldBe Defines.Regexp - literalArg.code shouldBe "/x/" - literalArg.lineNumber shouldBe Some(1) - } - - "have correct structure for a single left had side call" in { - val cpg = code("array[n] = 10") - val List(callNode) = cpg.call.name(Operators.indexAccess).l - callNode.code shouldBe "array[n]" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(5) - } - - "have correct structure for a binary expression" in { - val cpg = code("x+y") - val List(callNode) = cpg.call.name(Operators.addition).l - callNode.code shouldBe "x+y" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - } - - "have correct structure for a not expression" in { - val cpg = code("not y") - val List(callNode) = cpg.call.name(Operators.not).l - callNode.code shouldBe "not y" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - } - - "have correct structure for a power expression" in { - val cpg = code("x**y") - val List(callNode) = cpg.call.name(Operators.exponentiation).l - callNode.code shouldBe "x**y" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - } - - "have correct structure for a inclusive range expression" in { - val cpg = code("1..10") - val List(callNode) = cpg.call.name(Operators.range).l - callNode.code shouldBe "1..10" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - } - - "have correct structure for a non-inclusive range expression" in { - val cpg = code("1...10") - val List(callNode) = cpg.call.name(Operators.range).l - callNode.code shouldBe "1...10" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - } - - "have correct structure for a relational expression" in { - val cpg = code("x> y") - val List(callNode) = cpg.call.name(Operators.logicalShiftRight).l - callNode.code shouldBe "x >> y" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - } - - "have correct structure for a shift left expression" in { - val cpg = code("x << y") - val List(callNode) = cpg.call.name(Operators.shiftLeft).l - callNode.code shouldBe "x << y" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - } - - "have correct structure for a compare expression" in { - val cpg = code("x <=> y") - val List(callNode) = cpg.call.name(Operators.compare).l - callNode.code shouldBe "x <=> y" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - } - - "have correct structure for a indexing expression" in { - val cpg = code("def some_method(index)\n some_map[index]\nend") - val List(callNode) = cpg.call.name(Operators.indexAccess).l - callNode.code shouldBe "some_map[index]" - callNode.lineNumber shouldBe Some(2) - callNode.columnNumber shouldBe Some(9) - } - - "have correct structure for overloaded index operator method" in { - val cpg = code(""" - |class MyClass - |def [](key) - | @member_hash[key] - |end - |end - |""".stripMargin) - - val List(methodNode) = cpg.method.name("\\[]").l - methodNode.fullName shouldBe "Test0.rb::program.MyClass.[]" - methodNode.code shouldBe "def [](key)\n @member_hash[key]\nend" - methodNode.lineNumber shouldBe Some(3) - methodNode.lineNumberEnd shouldBe Some(5) - methodNode.columnNumber shouldBe Some(4) - } - - "have correct structure for overloaded equality operator method" in { - val cpg = code(""" - |class MyClass - |def ==(other) - | @my_member==other - |end - |end - |""".stripMargin) - - val List(methodNode) = cpg.method.name("==").l - methodNode.fullName shouldBe "Test0.rb::program.MyClass.==" - methodNode.code shouldBe "def ==(other)\n @my_member==other\nend" - methodNode.lineNumber shouldBe Some(3) - methodNode.lineNumberEnd shouldBe Some(5) - methodNode.columnNumber shouldBe Some(4) - } - - "have correct structure for class method" in { - val cpg = code(""" - |class MyClass - |def some_method(param) - |end - |end - |""".stripMargin) - - val List(methodNode) = cpg.method.name("some_method").l - methodNode.fullName shouldBe "Test0.rb::program.MyClass.some_method" - methodNode.code shouldBe "def some_method(param)\nend" - methodNode.lineNumber shouldBe Some(3) - methodNode.lineNumberEnd shouldBe Some(4) - methodNode.columnNumber shouldBe Some(4) - } - - "have correct structure for scope resolution operator call" in { - val cpg = code(""" - |def foo(param) - |::SomeConstant = param - |end - |""".stripMargin) - - val List(identifierNode) = cpg.identifier.name("SomeConstant").l - identifierNode.code shouldBe "SomeConstant" - identifierNode.lineNumber shouldBe Some(3) - identifierNode.columnNumber shouldBe Some(0) - } - - "have correct structure for a addition expression with space before addition" in { - val cpg = code("x + y") - val List(callNode) = cpg.call.name(Operators.addition).l - callNode.code shouldBe "x + y" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - } - - "have correct structure for a addition expression with space before subtraction" in { - val cpg = code("x - y") - val List(callNode) = cpg.call.name(Operators.subtraction).l - callNode.code shouldBe "x - y" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - } - - "have correct structure for object's method access (chainedInvocationPrimary)" in { - val cpg = code("object.some_method(arg1,arg2)") - val List(callNode) = cpg.call.name("some_method").l - callNode.code shouldBe "object.some_method(arg1,arg2)" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(6) - - val List(identifierNode1) = cpg.identifier.name("arg1").l - identifierNode1.code shouldBe "arg1" - identifierNode1.lineNumber shouldBe Some(1) - identifierNode1.columnNumber shouldBe Some(19) - - val List(identifierNode2) = cpg.identifier.name("arg2").l - identifierNode2.code shouldBe "arg2" - identifierNode2.lineNumber shouldBe Some(1) - identifierNode2.columnNumber shouldBe Some(24) - } - - "have correct structure for object's method.member access (chainedInvocationPrimary)" ignore { - val cpg = code("object.some_member") - val List(identifierNode) = cpg.identifier.name("some_member").l - identifierNode.code shouldBe "some_member" - identifierNode.lineNumber shouldBe Some(1) - identifierNode.columnNumber shouldBe Some(0) - } - - "have correct structure for negation before block (invocationExpressionOrCommand)" in { - val cpg = code("!foo arg do\nputs arg\nend") - - val List(callNode1) = cpg.call.name(Operators.not).l - callNode1.code shouldBe "!foo arg do\nputs arg\nend" - callNode1.lineNumber shouldBe Some(1) - callNode1.columnNumber shouldBe Some(0) - - val List(callNode2) = cpg.call.name("foo").l - callNode2.code shouldBe "foo arg do\nputs arg\nend" - callNode2.lineNumber shouldBe Some(1) - callNode2.columnNumber shouldBe Some(1) - - val List(callNode3) = cpg.call.name("puts").l - callNode3.code shouldBe "puts arg" - callNode3.lineNumber shouldBe Some(2) - callNode3.columnNumber shouldBe Some(0) - - val List(argArgumentOfPuts, argArgumentOfFoo) = cpg.identifier.name("arg").l - argArgumentOfFoo.code shouldBe "arg" - argArgumentOfFoo.lineNumber shouldBe Some(2) - argArgumentOfFoo.columnNumber shouldBe Some(5) - - argArgumentOfPuts.code shouldBe "arg" - argArgumentOfPuts.lineNumber shouldBe Some(1) - } - - "have correct structure for a hash initialisation" in { - val cpg = code("hashMap = {\"k1\" => 1, \"k2\" => 2}") - val callNodes = cpg.call.name(".keyValueAssociation").l - callNodes.size shouldBe 2 - callNodes.head.code shouldBe "\"k1\" => 1" - callNodes.head.lineNumber shouldBe Some(1) - callNodes.head.columnNumber shouldBe Some(16) - } - - "have correct structure for defined? command" in { - val cpg = code("defined? x") - - val List(callNode) = cpg.call.name(".defined").l - callNode.code shouldBe "defined? x" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - - val List(identifierNode) = cpg.identifier.name("x").l - identifierNode.code shouldBe "x" - identifierNode.lineNumber shouldBe Some(1) - identifierNode.columnNumber shouldBe Some(9) - } - - "have correct structure for defined? call" in { - val cpg = code("defined?(x)") - - val List(callNode) = cpg.call.name(".defined").l - callNode.code shouldBe "defined?(x)" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - - val List(identifierNode) = cpg.identifier.name("x").l - identifierNode.code shouldBe "x" - identifierNode.lineNumber shouldBe Some(1) - identifierNode.columnNumber shouldBe Some(9) - } - - "have correct structure for chainedInvocationWithoutArgumentsPrimary" in { - val cpg = code("object::foo do\nputs \"right here\"\nend") - - val List(callNode1) = cpg.call.name("foo").l - callNode1.code shouldBe "puts \"right here\"" - callNode1.lineNumber shouldBe Some(1) - callNode1.columnNumber shouldBe Some(3) - - val List(callNode2) = cpg.call.name("puts").l - callNode2.code shouldBe "puts \"right here\"" - callNode2.lineNumber shouldBe Some(2) - callNode2.columnNumber shouldBe Some(0) - } - - "have correct structure for require with an expression" in { - val cpg = code("Dir[Rails.root.join('a', 'b', '**', '*.rb')].each { |f| require f }") - - val List(callNode) = cpg.call.name("require").l - callNode.code shouldBe "require f" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(56) - } - - "have correct structure for undef" in { - val cpg = code("undef method1,method2") - - val List(callNode) = cpg.call.name(".undef").l - callNode.code shouldBe "undef method1,method2" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(0) - } - - "have correct structure for ternary if expression" in { - val cpg = code("a ? b : c") - val List(controlNode) = cpg.controlStructure.l - - controlNode.controlStructureType shouldBe ControlStructureTypes.IF - controlNode.code shouldBe "a ? b : c" - controlNode.lineNumber shouldBe Some(1) - controlNode.columnNumber shouldBe Some(0) - - val List(a) = controlNode.condition.isIdentifier.l - a.code shouldBe "a" - a.name shouldBe "a" - a.lineNumber shouldBe Some(1) - a.columnNumber shouldBe Some(0) - - val List(_, b, c) = controlNode.astChildren.isIdentifier.l - b.code shouldBe "b" - b.name shouldBe "b" - b.lineNumber shouldBe Some(1) - b.columnNumber shouldBe Some(4) - - c.code shouldBe "c" - c.name shouldBe "c" - c.lineNumber shouldBe Some(1) - c.columnNumber shouldBe Some(8) - } - - "have correct structure for if statement" in { - val cpg = code("""if x == 0 then - | puts 1 - |end - |""".stripMargin) - - val List(ifNode) = cpg.controlStructure.l - ifNode.controlStructureType shouldBe ControlStructureTypes.IF - ifNode.lineNumber shouldBe Some(1) - - val List(ifCondition, ifBlock) = ifNode.astChildren.l - ifCondition.code shouldBe "x == 0" - ifCondition.lineNumber shouldBe Some(1) - - val List(puts) = ifBlock.astChildren.l - puts.code shouldBe "puts 1" - puts.lineNumber shouldBe Some(2) - } - - "have correct structure for if-else statement" in { - val cpg = code("""if x == 0 then - | puts 1 - |else - | puts 2 - |end - |""".stripMargin) - - val List(ifNode) = cpg.controlStructure.l - ifNode.controlStructureType shouldBe ControlStructureTypes.IF - ifNode.lineNumber shouldBe Some(1) - - val List(ifCondition, ifBlock, elseBlock) = ifNode.astChildren.l - ifCondition.code shouldBe "x == 0" - ifCondition.lineNumber shouldBe Some(1) - - val List(puts1) = ifBlock.astChildren.l - puts1.code shouldBe "puts 1" - puts1.lineNumber shouldBe Some(2) - - val List(puts2) = elseBlock.astChildren.l - puts2.code shouldBe "puts 2" - puts2.lineNumber shouldBe Some(4) - } - - "have correct structure for class definition with body having only identifiers" in { - val cpg = code("class MyClass\nidentifier1\nidentifier2\nend") - - val List(identifierNode1) = cpg.identifier.name("identifier1").l - identifierNode1.code shouldBe "identifier1" - identifierNode1.lineNumber shouldBe Some(2) - identifierNode1.columnNumber shouldBe Some(0) - - val List(identifierNode2) = cpg.identifier.name("identifier2").l - identifierNode2.code shouldBe "identifier2" - identifierNode2.lineNumber shouldBe Some(3) - identifierNode2.columnNumber shouldBe Some(0) - } - - // NOTE: The representation for `super` may change, in order to accommodate its meaning. - // But until then, modelling it as a call seems the appropriate thing to do. - "have correct structure for `super` expression call without block" in { - val cpg = code("super(1)") - - val List(callNode) = cpg.call.l - callNode.code shouldBe "super(1)" - callNode.name shouldBe ".super" - callNode.lineNumber shouldBe Some(1) - - val List(literalArg) = callNode.argument.isLiteral.l - literalArg.argumentIndex shouldBe 1 - literalArg.code shouldBe "1" - literalArg.lineNumber shouldBe Some(1) - } - - "have correct structure for `super` command call without block" in { - val cpg = code("super 1") - - val List(callNode) = cpg.call.l - callNode.code shouldBe "super 1" - callNode.name shouldBe ".super" - callNode.lineNumber shouldBe Some(1) - - val List(literalArg) = callNode.argument.isLiteral.l - literalArg.argumentIndex shouldBe 1 - literalArg.code shouldBe "1" - literalArg.lineNumber shouldBe Some(1) - } - - "have generated call nodes for regex interpolation" in { - val cpg = code("/x#{Regexp.quote(foo)}b#{x+'z'}a/") - val List(literalNode) = cpg.literal.l - cpg.call.size shouldBe 2 - literalNode.code shouldBe "'z'" - } - - "have correct structure for keyword? named method usage usage" in { - val cpg = code("x = 1.nil?") - - val List(callNode) = cpg.call.nameExact("nil?").l - callNode.code shouldBe "1.nil?" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(5) - - val List(arg) = callNode.argument.isLiteral.l - arg.code shouldBe "1" - } - - "have correct structure for keyword usage inside association" in { - val cpg = code("foo if: x.nil?") - - val List(callNode) = cpg.call.nameExact("nil?").l - callNode.code shouldBe "x.nil?" - callNode.lineNumber shouldBe Some(1) - callNode.columnNumber shouldBe Some(9) - - val List(arg) = callNode.argument.isIdentifier.l - arg.code shouldBe "x" - - val List(assocCallNode) = cpg.call.nameExact(".activeRecordAssociation").l - assocCallNode.code shouldBe "if: x.nil?" - assocCallNode.lineNumber shouldBe Some(1) - assocCallNode.columnNumber shouldBe Some(6) - - assocCallNode.argument.size shouldBe 2 - assocCallNode.argument.argumentIndex(1).head.code shouldBe "if" - assocCallNode.argument.argumentIndex(2).head.code shouldBe "x.nil?" - } - - "have correct structure for proc definiton with procParameters and empty block" in { - val cpg = - code("-> (x,y) {}") - cpg.parameter.size shouldBe 2 - } - - "have correct structure for proc definiton with procParameters and non-empty block" in { - val cpg = - code("""-> (x,y) { - |if (x) - | y - |else - | b - |end - |}""".stripMargin) - cpg.parameter.size shouldBe 2 - val List(paramOne, paramTwo) = cpg.parameter.l - paramOne.name shouldBe "x" - paramTwo.name shouldBe "y" - cpg.ifBlock.size shouldBe 1 - } - - "have correct structure for proc definition with no parameters and empty block" in { - val cpg = code("-> {}") - cpg.parameter.size shouldBe 0 - } - - "have correct structure for proc definition with additional context" in { - val cpg = code( - "scope :get_all_doctors, -> { (select('id, first_name').where('role = :user_role', user_role: User.roles[:doctor])) }" - ) - cpg.parameter.size shouldBe 6 - cpg.call.name("proc_2").size shouldBe 1 - cpg.call.name("scope").size shouldBe 1 - cpg.call.name("where").size shouldBe 1 - cpg.call.name("select").size shouldBe 1 - cpg.call.name("roles").size shouldBe 1 - cpg.call.name(".activeRecordAssociation").size shouldBe 1 - cpg.call.name(".indexAccess").size shouldBe 1 - } - - "have correct structure when method called with safe navigation without parameters" in { - val cpg = code("foo&.bar") - cpg.call.size shouldBe 1 - } - - "have correct structure when method called with safe navigation with parameters with parantheses" in { - val cpg = code("foo&.bar(1)") - - val List(callNode) = cpg.call.l - val List(actualArg) = callNode.argument.argumentIndex(1).l - actualArg.code shouldBe "1" - cpg.argument.size shouldBe 2 - cpg.call.size shouldBe 1 - } - - "have correct structure when method called with safe navigation with parameters without parantheses" in { - val cpg = code("foo&.bar 1,2") - - val List(callNode) = cpg.call.l - val List(actualArg1) = callNode.argument.argumentIndex(1).l - actualArg1.code shouldBe "1" - val List(actualArg2) = callNode.argument.argumentIndex(2).l - actualArg2.code shouldBe "2" - cpg.argument.size shouldBe 3 - cpg.call.size shouldBe 1 - } - - "have correct structure when method call present in next line, with the second line starting with `.`" in { - val cpg = code("foo\n .bar(1)") - - val List(callNode) = cpg.call.l - cpg.call.size shouldBe 1 - callNode.code shouldBe ("foo\n .bar(1)") - callNode.name shouldBe "bar" - callNode.lineNumber shouldBe Some(2) - val List(actualArg) = callNode.argument.argumentIndex(1).l - actualArg.code shouldBe "1" - } - - "have correct structure when method call present in next line, with the first line ending with `.`" in { - val cpg = code("foo.\n bar(1)") - - val List(callNode) = cpg.call.l - cpg.call.size shouldBe 1 - callNode.code shouldBe ("foo.\n bar(1)") - callNode.name shouldBe "bar" - callNode.lineNumber shouldBe Some(1) - val List(actualArg) = callNode.argument.argumentIndex(1).l - actualArg.code shouldBe "1" - } - - "have correct structure for proc parameter with name" in { - val cpg = code("def foo(&block) end") - val List(actualParameter) = cpg.method("foo").parameter.l - actualParameter.name shouldBe "block" - } - - "have correct structure for proc parameter with no name" in { - val cpg = code("def foo(&) end") - val List(actualParameter) = cpg.method("foo").parameter.l - actualParameter.name shouldBe "param_0" - } - - "have correct structure when regular expression literal passed after `when`" in { - val cpg = code(""" - |case foo - | when /^ch/ - | bar - |end - |""".stripMargin) - - val List(literalArg) = cpg.literal.l - literalArg.typeFullName shouldBe Defines.Regexp - literalArg.code shouldBe "/^ch/" - literalArg.lineNumber shouldBe Some(3) - } - - "have correct structure when have interpolated double-quoted string literal" in { - val cpg = code(""" - |v = :"w x #{y} z" - |""".stripMargin) - - cpg.call.size shouldBe 4 - cpg.call.name(".formatString").head.code shouldBe """:"w x #{y} z"""" - cpg.call.name(".formatValue").head.code shouldBe "#{y}" - - cpg.literal.size shouldBe 2 - cpg.literal.code("w x ").size shouldBe 1 - cpg.literal.code(" z").size shouldBe 1 - - cpg.identifier.name("y").size shouldBe 1 - cpg.identifier.name("v").size shouldBe 1 - } - - "have correct structure when have non-interpolated double-quoted string literal" in { - val cpg = code(""" - |x = :"y z" - |""".stripMargin) - - cpg.call.size shouldBe 1 - val List(literal) = cpg.literal.l - literal.code shouldBe ":\"y z\"" - literal.typeFullName shouldBe Defines.Symbol - } - - "have correct structure when have symbol " in { - val cpg = code(s""" - |x = :"${10}" - |""".stripMargin) - - cpg.call.size shouldBe 1 - val List(literal) = cpg.literal.l - literal.typeFullName shouldBe Defines.Symbol - literal.code shouldBe ":\"10\"" - } - } - - "have correct structure when no RHS for a mandatory parameter is provided" in { - val cpg = code(""" - |def foo(bar:) - |end - |""".stripMargin) - - val List(parameterNode) = cpg.method("foo").parameter.l - parameterNode.name shouldBe "bar" - parameterNode.lineNumber shouldBe Some(2) - } - - "have correct structure when RHS for a mandatory parameter is provided" in { - val cpg = code(""" - |def foo(bar: world) - |end - |""".stripMargin) - - val List(parameterNode) = cpg.method("foo").parameter.l - parameterNode.name shouldBe "bar" - parameterNode.lineNumber shouldBe Some(2) - } - - // Change below test cases to focus on the argument of call `foo` - "have correct structure when a association is passed as an argument with parantheses" in { - val cpg = code("""foo(bar:)""".stripMargin) - - cpg.argument.size shouldBe 2 - cpg.argument.l(0).code shouldBe "bar:" - cpg.call.size shouldBe 2 - val List(callNode, operatorNode) = cpg.call.l - callNode.name shouldBe "foo" - operatorNode.name shouldBe ".activeRecordAssociation" - } - - "have correct structure when a association is passed as an argument without parantheses" in { - val cpg = code("""foo bar:""".stripMargin) - - cpg.argument.size shouldBe 2 - cpg.argument.l.head.code shouldBe "bar:" - - cpg.call.size shouldBe 2 - val List(callNode, operatorNode) = cpg.call.l - callNode.name shouldBe "foo" - operatorNode.name shouldBe ".activeRecordAssociation" - } - - "have correct structure with ternary operator with multiple line" in { - val cpg = code("""x = a ? - | b - |: c""".stripMargin) - - val List(controlNode) = cpg.controlStructure.l - controlNode.controlStructureType shouldBe ControlStructureTypes.IF - controlNode.code shouldBe "a ?\n b\n: c" - controlNode.lineNumber shouldBe Some(1) - controlNode.columnNumber shouldBe Some(4) - - val List(a) = controlNode.condition.isIdentifier.l - a.code shouldBe "a" - a.name shouldBe "a" - a.lineNumber shouldBe Some(1) - a.columnNumber shouldBe Some(4) - - val List(_, b, c) = controlNode.astChildren.isIdentifier.l - b.code shouldBe "b" - b.name shouldBe "b" - b.lineNumber shouldBe Some(2) - b.columnNumber shouldBe Some(1) - - c.code shouldBe "c" - c.name shouldBe "c" - c.lineNumber shouldBe Some(3) - c.columnNumber shouldBe Some(2) - } - - "have correct structure for blank indexing arguments" in { - val cpg = code(""" - |bar = Set[] - |""".stripMargin) - - val List(callNode) = cpg.call.name(".indexAccess").l - callNode.lineNumber shouldBe Some(2) - callNode.columnNumber shouldBe Some(9) - } - - "method defined inside a class using << operator" in { - val cpg = code(""" - class MyClass - | - | class << self - | def print - | puts "log #{self}" - | end - | end - | class << self - | end - |end - | - |MyClass.print""".stripMargin) - - val List(callNode) = cpg.call.name("print").l - callNode.lineNumber shouldBe Some(13) - callNode.columnNumber shouldBe Some(7) - callNode.name shouldBe "print" - } - - "have correct structure for body statements inside a do block" in { - val cpg = code(""" - |def foo - |1/0 - |rescue ZeroDivisionError => e - |end""".stripMargin) - - val List(methodNode) = cpg.method.code(".*foo.*").l - methodNode.name shouldBe "foo" - methodNode.lineNumber shouldBe Some(2) - - val List(assignmentOperator, divisionOperator) = cpg.method.name(".*operator.*").l - divisionOperator.name shouldBe ".division" - assignmentOperator.name shouldBe ".assignment" - } - - "have correct structure when regex literal is used on RHS of association" in { - val cpg = code(""" - |books = [ - | { - | id: /.*/ - | } - |] - |""".stripMargin) - - val List(assocOperator) = cpg.call(".*activeRecordAssociation.*").l - assocOperator.code shouldBe "id: /.*/" - assocOperator.astChildren.code.l(1) shouldBe "/.*/" - assocOperator.lineNumber shouldBe Some(4) - } - - "have double-quoted string literals containing \\u character" in { - val cpg = code(""" - |val fileName = "AB\u0003\u0004\u0014\u0000\u0000\u0000\b\u0000\u0000\u0000!\u0000file" - |""".stripMargin) - - cpg.identifier.size shouldBe 1 - cpg.identifier.name.head shouldBe "fileName" - cpg.literal.head.code - .stripPrefix("\"") - .stripSuffix("\"") - .trim shouldBe """AB\u0003\u0004\u0014\u0000\u0000\u0000\b\u0000\u0000\u0000!\u0000file""" - - } - - "have correct structure for a endless method" in { - val cpg = code(""" - |def foo(a,b) = a*b - |""".stripMargin) - - val List(methodNode) = cpg.method.name("foo").l - methodNode.lineNumber shouldBe Some(2) - methodNode.columnNumber shouldBe Some(4) - } - - "have correct structure for symbol literal defined using \\:" in { - val cpg = code(""" - |foo = {:bar=>zoo} - |""".stripMargin) - - val List(keyValueAssocOperator) = cpg.call(".*keyValueAssociation.*").l - keyValueAssocOperator.code shouldBe ":bar=>zoo" - keyValueAssocOperator.astChildren.l.head.code shouldBe ":bar" - keyValueAssocOperator.astChildren.l(1).code shouldBe "zoo" - } - - "having a binary expression includes + and @" in { - val cpg = code(""" - |class MyClass - | def initialize(a) - | @a = a - | end - | - | def calculate_x(b) - | x = b+@a - | return x - | end - |end - |""".stripMargin) - cpg.identifier("a").dedup.size shouldBe 1 - cpg.identifier("b").dedup.size shouldBe 1 - cpg.identifier("x").name.dedup.size shouldBe 1 - cpg.method("calculate_x").size shouldBe 1 - } - - "have correct structure for empty %w array" in { - val cpg = code(""" - |a = %w[] - |""".stripMargin) - - val List(assignmentCallNode) = cpg.call.name(Operators.assignment).l - assignmentCallNode.size shouldBe 1 - val List(arrayCallNode) = cpg.call.name(Operators.arrayInitializer).l - arrayCallNode.size shouldBe 1 - arrayCallNode.argument.size shouldBe 0 - } - - "have correct structure for %w array with %w()" in { - val cpg = code(""" - |a = %w(b c d) - |""".stripMargin) - - val List(assignmentCallNode) = cpg.call.name(Operators.assignment).l - assignmentCallNode.size shouldBe 1 - val List(arrayCallNode) = cpg.call.name(Operators.arrayInitializer).l - arrayCallNode.size shouldBe 1 - arrayCallNode.argument - .where(_.argumentIndex(1)) - .code - .l shouldBe List("b") - arrayCallNode.argument - .where(_.argumentIndex(2)) - .code - .l shouldBe List("c") - arrayCallNode.argument - .where(_.argumentIndex(3)) - .code - .l shouldBe List("d") - } - - "have correct structure for %w array with %w() with entries separated by whitespace" in { - val cpg = code(""" - |a = %w( - | bob - | cod - | dod - |) - |""".stripMargin) - - val List(assignmentCallNode) = cpg.call.name(Operators.assignment).l - assignmentCallNode.size shouldBe 1 - val List(arrayCallNode) = cpg.call.name(Operators.arrayInitializer).l - arrayCallNode.size shouldBe 1 - arrayCallNode.argument - .where(_.argumentIndex(1)) - .code - .l shouldBe List("bob") - arrayCallNode.argument - .where(_.argumentIndex(2)) - .code - .l shouldBe List("cod") - arrayCallNode.argument - .where(_.argumentIndex(3)) - .code - .l shouldBe List("dod") - } - - "have correct structure for %w array with %w- -" in { - val cpg = code(""" - |a = %w-b c- - |""".stripMargin) - - val List(assignmentCallNode) = cpg.call.name(Operators.assignment).l - assignmentCallNode.size shouldBe 1 - val List(arrayCallNode) = cpg.call.name(Operators.arrayInitializer).l - arrayCallNode.size shouldBe 1 - arrayCallNode.argument - .where(_.argumentIndex(1)) - .code - .l shouldBe List("b") - arrayCallNode.argument - .where(_.argumentIndex(2)) - .code - .l shouldBe List("c") - } - - "have correct structure for %i() array with two elements" in { - val cpg = code("x = %i(yy zz)") - - val List(arrayInit) = cpg.call.name(Operators.arrayInitializer).l - val List(yyNode: Literal, zzNode: Literal) = arrayInit.argument.isLiteral.l - - yyNode.code shouldBe "yy" - yyNode.argumentIndex shouldBe 1 - yyNode.typeFullName shouldBe Defines.Symbol - - zzNode.code shouldBe "zz" - zzNode.argumentIndex shouldBe 2 - zzNode.typeFullName shouldBe Defines.Symbol - } - - "have correct structure parenthesised arguments in a return jump" in { - val cpg = code("""return(value) unless item""".stripMargin) - - cpg.identifier.size shouldBe 2 - cpg.identifier.name("value").size shouldBe 1 - cpg.identifier.name("item").size shouldBe 1 - - val List(methodReturn) = cpg.ret.l - methodReturn.code shouldBe "return(value)" - methodReturn.lineNumber shouldBe Some(1) - methodReturn.columnNumber shouldBe Some(0) - } - - "have correct structure for a hash containing splatting elements" in { - val cpg = code(""" - |bar={:x=>1} - |foo = { - |**bar - |} - |""".stripMargin) - - val List(keyValueAssocOperator) = cpg.call(".*keyValueAssociation.*").l - keyValueAssocOperator.code shouldBe ":x=>1" - keyValueAssocOperator.astChildren.l(1).code shouldBe "1" - - val List(pseudoIdentifier, actualIdentifier) = cpg.identifier("bar").l - pseudoIdentifier.lineNumber shouldBe Some(2) - pseudoIdentifier.columnNumber shouldBe Some(0) - } - - "have correct structure for regex match global variables" in { - val cpg = code(""" - |content_filename =~ /filename="(.*)"/ - |value = $1 - |""".stripMargin) - - cpg.call.size shouldBe 2 - cpg.call.code(".*filename.*").head.methodFullName shouldBe ".patternMatch" - - cpg.identifier.code("value").size shouldBe 1 - cpg.identifier.name("\\$1").size shouldBe 1 - cpg.identifier.name("\\$1").head.typeFullName shouldBe Defines.String - - cpg.literal.code("/filename=\"(.*)\"/").head.typeFullName shouldBe Defines.Regexp - } - - "have correct structure of unless keyword and regex statement" in { - val cpg = code("""def contains_numbers?(string) - | # Define a regular expression pattern to match any digit - | regex_pattern = /\d/ - | - | # Check if the string contains any numbers using the 'unless' keyword - | unless string.match(regex_pattern).nil? - | return true - | end - | - | return false - |end""".stripMargin) - - cpg.identifier.code("regex_pattern").name.dedup.size shouldBe 1 - cpg.method("contains_numbers\\?").name.size shouldBe 1 - cpg.call(".assignment").name.size shouldBe 2 - cpg.call("match").name.size shouldBe 1 - } - - "have correct structure for association identifier" in { - val cpg = code(""" - |foo(a:b) - |""".stripMargin) - - cpg.call.size shouldBe 2 - cpg.call.name(".activeRecordAssociation").size shouldBe 1 - - cpg.identifier.size shouldBe 2 - cpg.identifier.name("a").size shouldBe 1 - cpg.identifier.name("b").size shouldBe 1 - } - - "have correct structure for multiline %i" in { - val cpg = code(""" - |w = %i(x y - | z) - |""".stripMargin) - - val List(arrayInit) = cpg.call.name(Operators.arrayInitializer).l - val List(xNode: Literal, yNode: Literal, zNode: Literal) = arrayInit.argument.isLiteral.l - - xNode.code shouldBe "x" - xNode.argumentIndex shouldBe 1 - xNode.typeFullName shouldBe Defines.Symbol - - yNode.code shouldBe "y" - yNode.argumentIndex shouldBe 2 - yNode.typeFullName shouldBe Defines.Symbol - - zNode.code shouldBe "z" - zNode.argumentIndex shouldBe 3 - zNode.typeFullName shouldBe Defines.Symbol - } - - "have correct structure for packing LHS in multiple assignment" in { - val cpg = code(""" - |some_lhs,*pack_lhs = some_rhs,pack_rhs1, pack_rhs2 - |""".stripMargin) - - cpg.call.name(".assignment").size shouldBe 2 - val callNode1 = cpg.call.code("some_lhs = some_rhs").l.head - callNode1.lineNumber shouldBe Some(2) - callNode1.columnNumber shouldBe Some(19) - val args1 = callNode1.argument.l - args1.size shouldBe 2 - args1.head.code shouldBe "some_lhs" - args1.tail.code.l.head shouldBe "some_rhs" - - val callNode2 = cpg.call.code("pack_lhs = pack_rhs1, pack_rhs2").l.head - callNode2.lineNumber shouldBe Some(2) - callNode2.columnNumber shouldBe Some(19) - val args2 = callNode2.argument.l - args2.size shouldBe 2 - args2.head.code shouldBe "pack_lhs" - args2.tail.code.l.head shouldBe "pack_rhs1, pack_rhs2" - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/TypeDeclAstCreationPassTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/TypeDeclAstCreationPassTest.scala deleted file mode 100644 index 0eea15e938c8..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/TypeDeclAstCreationPassTest.scala +++ /dev/null @@ -1,290 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.ModifierTypes -import io.shiftleft.semanticcpg.language.* -import io.joern.x2cpg.Defines as XDefines -import io.shiftleft.codepropertygraph.generated.Operators - -class TypeDeclAstCreationPassTest extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "AST generation for simple classes declarations" should { - - "generate a basic type declaration node for an empty class" in { - val cpg = code(""" - |class MyClass - |end - |""".stripMargin) - val List(myClass) = cpg.typeDecl.nameExact("MyClass").l - myClass.name shouldBe "MyClass" - myClass.fullName shouldBe "Test0.rb::program.MyClass" - } - - // TODO: Need to be fixed. - "generate a basic type declaration node for an empty class with Class.new" ignore { - val cpg = code(""" - |MyClass = Class.new do - |end - |""".stripMargin) - val List(myClass) = cpg.typeDecl.nameExact("MyClass").l - myClass.name shouldBe "MyClass" - myClass.fullName shouldBe "Test0.rb::program.MyClass" - } - - // TODO: Need to be fixed. - "populate class name correctly for a derived class" in { - val cpg = code(""" - |module ApplicationCable - | class Channel < ActionCable::Channel::Base - | end - |end - |""".stripMargin) - val List(myClass) = cpg.typeDecl.nameExact("Channel").l - myClass.name shouldBe "Channel" - myClass.fullName shouldBe "Test0.rb::program.ApplicationCable.Channel" - } - - "generate methods under type declarations" in { - val cpg = code(""" - |class Vehicle - | - | def self.speeding - | "Hello, from a class method" - | end - | - | def Vehicle.halting - | "Hello, from another class method" - | end - | - | def driving - | "Hello, from an instance method" - | end - | - |end - |""".stripMargin) - val List(vehicle) = cpg.typeDecl.nameExact("Vehicle").l - vehicle.name shouldBe "Vehicle" - vehicle.fullName shouldBe "Test0.rb::program.Vehicle" - - val List(_, speeding, halting, driving) = vehicle.method.l - speeding.name shouldBe "speeding" - halting.name shouldBe "halting" - driving.name shouldBe "driving" - - speeding.fullName shouldBe "Test0.rb::program.Vehicle.speeding" - halting.fullName shouldBe "Test0.rb::program.Vehicle.halting" - driving.fullName shouldBe "Test0.rb::program.Vehicle.driving" - } - - "generate members for various class members under the respective type declaration" in { - val cpg = code(""" - |class Song - | @@plays = 0 - | def initialize(name, artist, duration) - | @name = name - | @artist = artist - | @duration = duration - | end - |end - |""".stripMargin) - val List(song) = cpg.typeDecl.nameExact("Song").l - song.name shouldBe "Song" - song.fullName shouldBe "Test0.rb::program.Song" - - val List(classInit) = song.method.name(XDefines.StaticInitMethodName).l - classInit.fullName shouldBe s"Test0.rb::program.Song.${XDefines.StaticInitMethodName}" - val List(playsDef) = classInit.call.nameExact(Operators.fieldAccess).fieldAccess.l - playsDef.fieldIdentifier.canonicalName.headOption shouldBe Option("plays") - - val List(artist, duration, name, plays) = song.member.l - - plays.name shouldBe "plays" - name.name shouldBe "name" - artist.name shouldBe "artist" - duration.name shouldBe "duration" - - cpg.fieldAccess.fieldIdentifier.canonicalName.l shouldBe List("plays", "name", "artist", "duration") - } - - "generate members for various class members when using the `attr_reader` and `attr_writer` idioms" ignore { - val cpg = code(""" - |class Song - | attr_reader :name, :artist, :duration - | attr_writer :album - |end - |""".stripMargin) - val List(song) = cpg.typeDecl.nameExact("Song").l - song.name shouldBe "Song" - song.fullName shouldBe "Test0.rb::program.Song" - - val List(name, artist, duration, album) = song.member.l - name.name shouldBe "name" - artist.name shouldBe "artist" - duration.name shouldBe "duration" - album.name shouldBe "album" - } - - "generate methods with the correct access control modifiers case 1" in { - val cpg = code(""" - |class MyClass - | - | def method1 # default is 'public' - | #... - | end - | - | protected # subsequent methods will be 'protected' - | - | def method2 # will be 'protected' - | #... - | end - | - | private # subsequent methods will be 'private' - | - | def method3 # will be 'private' - | #... - | end - | - | public # subsequent methods will be 'public' - | - | def method4 # and this will be 'public' - | #... - | end - |end - |""".stripMargin) - val List(myClass) = cpg.typeDecl.nameExact("MyClass").l - myClass.name shouldBe "MyClass" - myClass.fullName shouldBe "Test0.rb::program.MyClass" - - val List(_, _, m1, m2, m3, m4) = myClass.method.l - m1.name shouldBe "method1" - m2.name shouldBe "method2" - m3.name shouldBe "method3" - m4.name shouldBe "method4" - - m1.fullName shouldBe "Test0.rb::program.MyClass.method1" - m2.fullName shouldBe "Test0.rb::program.MyClass.method2" - m3.fullName shouldBe "Test0.rb::program.MyClass.method3" - m4.fullName shouldBe "Test0.rb::program.MyClass.method4" - - m1.modifier.modifierType.l shouldBe List(ModifierTypes.PUBLIC) - m2.modifier.modifierType.l shouldBe List(ModifierTypes.PROTECTED) - m3.modifier.modifierType.l shouldBe List(ModifierTypes.PRIVATE) - m4.modifier.modifierType.l shouldBe List(ModifierTypes.PUBLIC) - } - - "generate methods with the correct access control modifiers case 2" ignore { - val cpg = code(""" - |class MyClass - | - | def method1 - | end - | - | def method2 - | end - | - | def method3 - | end - | - | def method4 - | end - | - | public :method1, :method4 - | protected :method2 - | private :method3 - |end - |""".stripMargin) - val List(myClass) = cpg.typeDecl.nameExact("MyClass").l - myClass.name shouldBe "MyClass" - myClass.fullName shouldBe "Test0.rb::program.MyClass" - - val List(m1, m2, m3, m4) = myClass.method.l - m1.name shouldBe "method1" - m2.name shouldBe "method2" - m3.name shouldBe "method3" - m4.name shouldBe "method4" - - m1.fullName shouldBe "Test0.rb::program.MyClass.method1" - m2.fullName shouldBe "Test0.rb::program.MyClass.method2" - m3.fullName shouldBe "Test0.rb::program.MyClass.method3" - m4.fullName shouldBe "Test0.rb::program.MyClass.method4" - - m1.modifier.modifierType.l shouldBe List(ModifierTypes.PUBLIC) - m2.modifier.modifierType.l shouldBe List(ModifierTypes.PROTECTED) - m3.modifier.modifierType.l shouldBe List(ModifierTypes.PRIVATE) - m4.modifier.modifierType.l shouldBe List(ModifierTypes.PUBLIC) - } - - } - - "Polymorphism in classes" should { - - "correctly contain the inherited base type name in the super type" ignore { - val cpg = code(""" - |class GeeksforGeeks - | def initialize - | puts "This is Superclass" - | end - | - | def super_method - | puts "Method of superclass" - | end - |end - | - |class Sudo_Placement < GeeksforGeeks - | def initialize - | puts "This is Subclass" - | end - |end - |""".stripMargin) - - val List(baseType) = cpg.typeDecl.nameExact("GeeksforGeeks").l - baseType.name shouldBe "GeeksforGeeks" - baseType.fullName shouldBe "Test0.rb::program.GeeksforGeeks" - - val List(subType) = cpg.typeDecl.nameExact("Sudo_Placement").l - subType.name shouldBe "Sudo_Placement" - subType.fullName shouldBe "Test0.rb::program.Sudo_Placement" - subType.inheritsFromTypeFullName shouldBe Seq("Test0.rb::program.GeeksforGeeks") - } - - } - - "Hierarchical class checks with constants" ignore { - val cpg = code(""" - |class MyClass - | MY_CONSTANT = 0 - | @@plays = 0 - | class ChildCls - | @@name = 0 - | MY_CONSTANT = 0 - | end - |end - |""".stripMargin) - - "member variables structure in place" in { - val List(clsInit1, clsInit2) = cpg.method(XDefines.StaticInitMethodName).l - clsInit1.fullName shouldBe s"Test0.rb::program.MyClass.${XDefines.StaticInitMethodName}" - val List(myconstantfa, playsfa) = clsInit1.call.nameExact(Operators.fieldAccess).fieldAccess.l - myconstantfa.fieldIdentifier.canonicalName.headOption shouldBe Option("MY_CONSTANT") - playsfa.fieldIdentifier.canonicalName.headOption shouldBe Option("plays") - - clsInit2.fullName shouldBe s"Test0.rb::program.MyClass.ChildCls.${XDefines.StaticInitMethodName}" - val List(namefa, myconstant2fa) = clsInit2.call.nameExact(Operators.fieldAccess).fieldAccess.l - myconstant2fa.fieldIdentifier.canonicalName.headOption shouldBe Option("MY_CONSTANT") - namefa.fieldIdentifier.canonicalName.headOption shouldBe Option("name") - - val List(myclassTd2) = cpg.typeDecl("ChildCls").l - val List(namem, myConstant2m) = myclassTd2.member.l - myConstant2m.name shouldBe "MY_CONSTANT" - namem.name shouldBe "name" - cpg.fieldAccess.fieldIdentifier.canonicalName.l shouldBe List("name", "MY_CONSTANT") - - val List(myclassTd) = cpg.typeDecl("MyClass").l - val List(myconstantm, playsm) = myclassTd.member.l - myconstantm.name shouldBe "MY_CONSTANT" - playsm.name shouldBe "plays" - cpg.fieldAccess.fieldIdentifier.canonicalName.l shouldBe List("MY_CONSTANT", "plays") - - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/UnaryOpCpgTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/UnaryOpCpgTests.scala deleted file mode 100644 index 8cf301a99371..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/ast/UnaryOpCpgTests.scala +++ /dev/null @@ -1,47 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.ast - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -import io.shiftleft.semanticcpg.language.* - -class UnaryOpCpgTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - "#unaryOp not" should { - val cpg = code("""!true""".stripMargin) - "test unaryOp 'not' call node properties" in { - val plusCall = cpg.call.methodFullName(Operators.not).head - plusCall.code shouldBe "!true" - plusCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - plusCall.lineNumber shouldBe Some(1) - } - - "test unaryOp 'not' arguments" in { - cpg.call - .methodFullName(Operators.not) - .argument - .argumentIndex(1) - .isLiteral - .head - .code shouldBe "true" - } - } - - "#unaryOp invert" should { - val cpg = code("""~2""".stripMargin) - "test unaryOp 'invert' call node properties" in { - val plusCall = cpg.call.methodFullName(Operators.not).head - plusCall.code shouldBe "~2" - plusCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - plusCall.lineNumber shouldBe Some(1) - } - - "test unaryOp 'invert' arguments" in { - cpg.call - .methodFullName(Operators.not) - .argument - .argumentIndex(1) - .isLiteral - .head - .code shouldBe "2" - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/cfg/SimpleCfgCreationPassTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/cfg/SimpleCfgCreationPassTest.scala deleted file mode 100644 index 4e72b25ef4bd..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/passes/cfg/SimpleCfgCreationPassTest.scala +++ /dev/null @@ -1,132 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.passes.cfg - -import io.joern.rubysrc2cpg.testfixtures.RubyCfgTestCpg -import io.joern.x2cpg.passes.controlflow.cfgcreation.Cfg.AlwaysEdge -import io.joern.x2cpg.testfixtures.CfgTestFixture -import io.shiftleft.codepropertygraph.generated.Cpg - -class SimpleCfgCreationPassTest extends CfgTestFixture(() => new RubyCfgTestCpg(useDeprecatedFrontend = true)) { - - "CFG generation for simple fragments" should { - "have correct structure for empty array literal" ignore { - implicit val cpg: Cpg = code("x = []") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("x = []", AlwaysEdge)) - succOf("x = []") shouldBe expected(("RET", AlwaysEdge)) - } - - "have correct structure for array literal with values" in { - implicit val cpg: Cpg = code("x = [1, 2]") - succOf("1") shouldBe expected(("2", AlwaysEdge)) - succOf("x = [1, 2]") shouldBe expected(("RET", AlwaysEdge)) - } - - "assigning a literal value" in { - implicit val cpg: Cpg = code("x = 1") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x = 1") shouldBe expected(("RET", AlwaysEdge)) - } - - "assigning a string literal value" in { - implicit val cpg: Cpg = code("x = 'some literal'") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x = 'some literal'") shouldBe expected(("RET", AlwaysEdge)) - } - - "addition of two numbers" in { - implicit val cpg: Cpg = code("x = 1 + 2") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x = 1 + 2") shouldBe expected(("RET", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("2") shouldBe expected(("1 + 2", AlwaysEdge)) - succOf("1") shouldBe expected(("2", AlwaysEdge)) - succOf("1 + 2") shouldBe expected(("x = 1 + 2", AlwaysEdge)) - } - - "addition of two string" in { - implicit val cpg: Cpg = code("x = 1 + 2") - succOf(":program") shouldBe expected(("x", AlwaysEdge)) - succOf("x = 1 + 2") shouldBe expected(("RET", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("2") shouldBe expected(("1 + 2", AlwaysEdge)) - succOf("1") shouldBe expected(("2", AlwaysEdge)) - succOf("1 + 2") shouldBe expected(("x = 1 + 2", AlwaysEdge)) - } - - "addition of multiple string" in { - implicit val cpg: Cpg = code(""" - |a = "Nice to meet you" - |b = ", " - |c = "do you like blueberries?" - |a+b+c - |""".stripMargin) - succOf(":program") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("\"Nice to meet you\"", AlwaysEdge)) - succOf("b") shouldBe expected(("\", \"", AlwaysEdge)) - succOf("c") shouldBe expected(("\"do you like blueberries?\"", AlwaysEdge)) - succOf("a+b+c") shouldBe expected(("RET", AlwaysEdge)) - succOf("a+b") shouldBe expected(("c", AlwaysEdge)) - succOf("\"Nice to meet you\"") shouldBe expected(("a = \"Nice to meet you\"", AlwaysEdge)) - succOf("\", \"") shouldBe expected(("b = \", \"", AlwaysEdge)) - succOf("\"do you like blueberries?\"") shouldBe expected(("c = \"do you like blueberries?\"", AlwaysEdge)) - } - - "addition of multiple string and assign to variable" in { - implicit val cpg: Cpg = code(""" - |a = "Nice to meet you" - |b = ", " - |c = "do you like blueberries?" - |x = a+b+c - |""".stripMargin) - succOf(":program") shouldBe expected(("a", AlwaysEdge)) - succOf("a") shouldBe expected(("\"Nice to meet you\"", AlwaysEdge)) - succOf("b") shouldBe expected(("\", \"", AlwaysEdge)) - succOf("c") shouldBe expected(("\"do you like blueberries?\"", AlwaysEdge)) - succOf("a+b+c") shouldBe expected(("x = a+b+c", AlwaysEdge)) - succOf("a+b") shouldBe expected(("c", AlwaysEdge)) - succOf("\"Nice to meet you\"") shouldBe expected(("a = \"Nice to meet you\"", AlwaysEdge)) - succOf("\", \"") shouldBe expected(("b = \", \"", AlwaysEdge)) - succOf("\"do you like blueberries?\"") shouldBe expected(("c = \"do you like blueberries?\"", AlwaysEdge)) - succOf("x") shouldBe expected(("a", AlwaysEdge)) - } - - "single hierarchy of if else statement" in { - implicit val cpg: Cpg = code(""" - |x = 1 - |if x > 2 - | puts "x is greater than 2" - |end - |""".stripMargin) - succOf(":program") shouldBe expected(("puts", AlwaysEdge)) - succOf("puts") shouldBe expected(("__builtin.puts", AlwaysEdge)) - succOf("__builtin.puts") shouldBe expected(("puts = __builtin.puts", AlwaysEdge)) - succOf("puts = __builtin.puts") shouldBe expected(("x", AlwaysEdge)) - succOf("1") shouldBe expected(("x = 1", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("2") shouldBe expected(("x > 2", AlwaysEdge)) - } - - "multiple hierarchy of if else statement" in { - implicit val cpg: Cpg = code(""" - |x = 1 - |if x > 2 - | puts "x is greater than 2" - |elsif x <= 2 and x!=0 - | puts "x is 1" - |else - | puts "I can't guess the number" - |end - |""".stripMargin) - succOf(":program") shouldBe expected(("puts", AlwaysEdge)) - succOf("puts") shouldBe expected(("__builtin.puts", AlwaysEdge)) - succOf("__builtin.puts") shouldBe expected(("puts = __builtin.puts", AlwaysEdge)) - succOf("puts = __builtin.puts") shouldBe expected(("x", AlwaysEdge)) - succOf("1") shouldBe expected(("x = 1", AlwaysEdge)) - succOf("x") shouldBe expected(("1", AlwaysEdge)) - succOf("2") shouldBe expected(("x > 2", AlwaysEdge)) - succOf("x <= 2 and x!=0") subsetOf expected(("\"x is 1\"", AlwaysEdge)) - succOf("x <= 2 and x!=0") subsetOf expected(("RET", AlwaysEdge)) - } - - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/AssignmentTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/AssignmentTests.scala deleted file mode 100644 index 7ad9ad37ad09..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/AssignmentTests.scala +++ /dev/null @@ -1,87 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.querying - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language.* - -class AssignmentTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "CPG for code with identifiers and literals in simple assignments" should { - val cpg = code(""" - |# call instance methods - |a = 1 - |b = 2 - |a = 3 - |b = 4 - |c = a*b - |puts "Multiplication is : #{c}" - |""".stripMargin) - - "recognize all assignment nodes" in { - cpg.assignment.size shouldBe 6 // One assignment is for `puts = typeRef(__builtin.puts)` - } - - "have call nodes for .assignment as method name" in { - cpg.assignment.foreach { assignment => - assignment.name shouldBe Operators.assignment - assignment.methodFullName shouldBe Operators.assignment - } - } - - "should have identifiers as LHS for each assignment node" in { - cpg.call.nameExact(Operators.assignment).argument.where(_.argumentIndex(1)).foreach { idx => - idx.isIdentifier shouldBe true - } - } - - "recognise all identifier nodes" in { - cpg.identifier.name("a").size shouldBe 3 - cpg.identifier.name("b").size shouldBe 3 - cpg.identifier.name("c").size shouldBe 2 - } - - "recognise all literal nodes" in { - cpg.literal.code("1").size shouldBe 1 - cpg.literal.code("2").size shouldBe 1 - cpg.literal.code("3").size shouldBe 1 - cpg.literal.code("4").size shouldBe 1 - } - } - - "CPG for code with multiple assignments" should { - val cpg = code(""" - |a, b, c = [1, 2, 3] - |a, b, c = b, c, a - |str1, str2 = ["hello", "world"] - |p, q = [foo(), bar()] - |""".stripMargin) - - "recognise all identifier nodes" in { - cpg.identifier.name("a").size shouldBe 3 - cpg.identifier.name("b").size shouldBe 3 - cpg.identifier.name("c").size shouldBe 3 - cpg.identifier.name("str1").size shouldBe 1 - cpg.identifier.name("str2").size shouldBe 1 - cpg.identifier.name("p").size shouldBe 1 - cpg.identifier.name("q").size shouldBe 1 - } - - "recognise all literal nodes" in { - cpg.literal.code("1").size shouldBe 1 - cpg.literal.code("2").size shouldBe 1 - cpg.literal.code("3").size shouldBe 1 - cpg.literal.code("\"hello\"").size shouldBe 1 - cpg.literal.code("\"world\"").size shouldBe 1 - } - - "recognize call nodes in RHS" in { - cpg.call.codeExact("foo()").size shouldBe 1 - cpg.call.codeExact("bar()").size shouldBe 1 - } - - "recognise all assignment call nodes" in { - /* here we are also checking the synthetic assignment nodes for each element on both sides */ - cpg.call.name(Operators.assignment).size shouldBe 10 - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/CallGraphTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/CallGraphTests.scala deleted file mode 100644 index ff46a8e689eb..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/CallGraphTests.scala +++ /dev/null @@ -1,29 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.querying - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.nodes.Method -import io.shiftleft.semanticcpg.language.* - -class CallGraphTests extends RubyCode2CpgFixture(withPostProcessing = true, useDeprecatedFrontend = true) { - - val cpg = code(""" - |def bar(content) - |puts content - |end - | - |def foo - |bar( 1 ) - |end - |""".stripMargin) - - "should identify call from `foo` to `bar`" in { - val List(callToBar) = cpg.call("bar").l - callToBar.name shouldBe "bar" - callToBar.methodFullName shouldBe "Test0.rb::program.bar" - callToBar.lineNumber shouldBe Some(7) - val List(bar: Method) = cpg.method("bar").internal.l - bar.fullName shouldBe callToBar.methodFullName - bar.caller.name.l shouldBe List("foo") - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/ControlStructureTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/ControlStructureTests.scala deleted file mode 100644 index 12074751fcc6..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/ControlStructureTests.scala +++ /dev/null @@ -1,396 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.querying - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.ControlStructureTypes -import io.shiftleft.codepropertygraph.generated.nodes.{Block, ControlStructure} -import io.shiftleft.semanticcpg.language.* -class ControlStructureTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "CPG for code with doBlock iterating over a constant array" should { - val cpg = code(""" - |[1, 2, "three"].each do |n| - | puts n - |end - |""".stripMargin) - - "recognise all identifier nodes" in { - cpg.identifier.name("n").size shouldBe 1 - cpg.identifier.size shouldBe 2 // 1 identifier node is for `puts = typeDef(__builtin.puts)` - } - - "recognize all call nodes" in { - cpg.call.name("each").size shouldBe 1 - cpg.call.name("puts").size shouldBe 1 - } - } - - "CPG for code iterating over hash discarding key using _" should { - val cpg = code(""" - |x.each do |_, y| - | puts y - |end - |""".stripMargin) - - "have a valid each call and method" in { - cpg.call("each").size shouldBe 1 - cpg.call("each").argument.where(_.isIdentifier).code.l shouldBe List("x") - } - - "have valid identifiers" in { - cpg.identifier.name("x").size shouldBe 1 - cpg.identifier.name("y").size shouldBe 1 - } - } - - "CPG for code with doBlock iterating over a constant array and multiple params" should { - val cpg = code(""" - |[1, 2, "three"].each do |n, m| - | expect { - | someObject.someMethod(n) - | someObject.someMethod(m) - | }.to otherMethod(n).by(1) - |end - | - |""".stripMargin) - - "recognise all identifier nodes" in { - cpg.identifier.name("n").size shouldBe 2 - cpg.identifier.name("m").size shouldBe 1 - cpg.identifier.size shouldBe 5 - cpg.method.name("fakeName").dotAst.l - } - - "recognize all call nodes" in { - cpg.call.name("each").size shouldBe 1 - cpg.call.name("someMethod").size shouldBe 2 - cpg.call.name("expect").size shouldBe 1 - cpg.call.name("to").size shouldBe 1 - cpg.call.name("otherMethod").size shouldBe 1 - cpg.call.name("by").size shouldBe 1 - } - } - - "CPG for code with return having an if statement" should { - val cpg = code(""" - |def some_method - | return if some_var - |end - | - |""".stripMargin) - - /* - * This code used jumpExpression. This validated t - */ - "recognise identifier nodes in the jump statement" in { - cpg.identifier.name("some_var").size shouldBe 1 - } - - "identify the control structure code" in { - cpg.controlStructure.code("return if some_var").size shouldBe 1 - } - } - - "CPG for code with yield" should { - val cpg = code(""" - |def yield_with_args_method - | yield 2*3 - | yield 100 - | yield - |end - | - |yield_with_args_method {|i| puts "arg is #{i}"} - | - |""".stripMargin) - - "recognise all method nodes" in { - cpg.method.name("yield_with_args_method").size shouldBe 1 - cpg.method.name("yield_with_args_method_yield").size shouldBe 1 - } - } - - "CPG for code with if/else condition" should { - val cpg = code(""" - |x = 1 - |if x > 2 - | puts "x is greater than 2" - |elsif x <= 2 and x!=0 - | puts "x is 1" - |else - | puts "I can't guess the number" - |end - | - |""".stripMargin) - - "recognise all identifier nodes" in { - cpg.identifier.name("x").size shouldBe 4 - } - - "recognize all literal nodes" in { - cpg.literal.code("1").size shouldBe 1 - cpg.literal.code("2").size shouldBe 2 - cpg.literal.code("0").size shouldBe 1 - cpg.literal.code("\"x is 1\"").size shouldBe 1 - cpg.literal.code("\"I can't guess the number\"").size shouldBe 1 - } - } - - "CPG for code with conditional operator" should { - val cpg = code(""" - |y = ( x > 2 ) ? x : x + 1 - |""".stripMargin) - - "recognise all literal and identifier nodes" in { - cpg.identifier.name("x").size shouldBe 3 - cpg.identifier.name("y").size shouldBe 1 - cpg.literal.code("1").size shouldBe 1 - } - } - - "CPG for code with unless condition" should { - val cpg = code(""" - |x = 1 - |unless x > 2 - | puts "x is less than or equal to 2" - |else - | puts "x is greater than 2" - |end - | - |""".stripMargin) - - "recognise all literal nodes" in { - cpg.identifier.name("x").size shouldBe 2 - cpg.literal.code("2").size shouldBe 1 - cpg.literal.code("\"x is less than or equal to 2\"").size shouldBe 1 - cpg.literal.code("\"x is greater than 2\"").size shouldBe 1 - } - - "recognise all call nodes" in { - cpg.call.name("puts").size shouldBe 2 - } - } - - "CPG for code with case statement and case argument" should { - val cpg = code(""" - |choice = "5" - |case choice - |when "1","2" - | puts "1 or 2" - |when "3","4" - | puts "3 or 4" - |when "5","6" - | puts "5 or 6" - |when "7","8" - | puts "7 or 8" - |else - | "No match" - |end - | - |""".stripMargin) - - "recognise all literal nodes" in { - cpg.identifier.name("choice").size shouldBe 2 - cpg.literal.code("\"1\"").size shouldBe 1 - cpg.literal.code("\"2\"").size shouldBe 1 - cpg.literal.code("\"3\"").size shouldBe 1 - cpg.literal.code("\"4\"").size shouldBe 1 - cpg.literal.code("\"5\"").size shouldBe 2 - cpg.literal.code("\"6\"").size shouldBe 1 - cpg.literal.code("\"7\"").size shouldBe 1 - cpg.literal.code("\"8\"").size shouldBe 1 - cpg.literal.code("\"1 or 2\"").size shouldBe 1 - cpg.literal.code("\"3 or 4\"").size shouldBe 1 - cpg.literal.code("\"5 or 6\"").size shouldBe 1 - cpg.literal.code("\"7 or 8\"").size shouldBe 1 - } - - "recognise all call nodes" in { - cpg.call.name("puts").size shouldBe 4 - } - } - - "CPG for code with case statement and no case" should { - val cpg = code(""" - |str = "some_string" - | - |case - |when str.match('/\d/') - | puts 'String contains numbers' - |when str.match('/[a-zA-Z]/') - | puts 'String contains letters' - |else - | puts 'String does not contain numbers & letters' - |end - | - |""".stripMargin) - - "recognise all literal nodes" in { - cpg.identifier.name("str").size shouldBe 3 - cpg.literal.code("\"some_string\"").size shouldBe 1 - cpg.literal.code("'String contains numbers'").size shouldBe 1 - cpg.literal.code("'String contains letters'").size shouldBe 1 - cpg.literal.code("'String does not contain numbers & letters'").size shouldBe 1 - } - - "recognise all call nodes" in { - cpg.call.name("puts").size shouldBe 3 - } - } - - "CPG for code with a while loop" should { - val cpg = code(""" - |x = 10 - |while x >= 1 - | x = x - 1 - | puts "In the loop" - |end - |""".stripMargin) - - "recognise all method nodes" in { - cpg.identifier - .name("x") - .size shouldBe 4 // FIXME this shows as 3 when the puts is the first loop statemnt. Find why - cpg.literal.code("\"In the loop\"").size shouldBe 1 - } - - "recognise all call nodes" in { - cpg.call.name("puts").size shouldBe 1 - } - } - - "CPG for code with a until loop" should { - val cpg = code(""" - |x = 10 - |until x == 0 - | puts "In the loop" - | x = x - 1 - |end - |""".stripMargin) - - "recognise all method nodes" in { - cpg.identifier.name("x").size shouldBe 4 - cpg.literal.code("\"In the loop\"").size shouldBe 1 - } - - "recognise all call nodes" in { - cpg.call.name("puts").size shouldBe 1 - } - - "recognise `until` as a `while` control structure" in { - val List(controlStructure) = cpg.whileBlock.l - controlStructure.lineNumber shouldBe Some(3) - - val List(condition) = controlStructure.astChildren.isCall.l - condition.code shouldBe "x == 0" - condition.lineNumber shouldBe Some(3) - - val List(body) = controlStructure.astChildren.isBlock.l - val List(puts, assignment) = body.astChildren.l - puts.code shouldBe "puts \"In the loop\"" - puts.lineNumber shouldBe Some(4) - assignment.lineNumber shouldBe Some(5) - assignment.assignment.size shouldBe 1 - } - - } - - "CPG for code with a for loop" should { - val cpg = code(""" - |for x in 1..10 do - | puts x - |end - |""".stripMargin) - - "recognise all literal nodes" in { - cpg.identifier.name("x").size shouldBe 2 - cpg.literal.code("1").size shouldBe 1 - cpg.literal.code("10").size shouldBe 1 - - } - - "recognise all call nodes" in { - cpg.call.name("puts").size shouldBe 1 - } - } - - "CPG for code with modifier statements" should { - val cpg = code(""" - |for i in 1..10 - | next if i % 2 == 0 - | redo if i > 8 - | retry if i > 7 - | puts i if i == 9 - | i += 4 unless i > 5 - | - | value1 = 0 - | value1 += 1 while value1 < 100 - | - | value2 = 0 - | value2 += 1 until value2 >= 100 - |end - |""".stripMargin) - - "recognise all identifier nodes" in { - cpg.identifier.name("i").size shouldBe 8 - cpg.identifier.name("value1").size shouldBe 3 - cpg.identifier.name("value2").size shouldBe 3 - } - - "recognize all literal nodes" in { - cpg.literal.code("1").size shouldBe 3 - cpg.literal.code("2").size shouldBe 1 - cpg.literal.code("8").size shouldBe 1 - cpg.literal.code("7").size shouldBe 1 - cpg.literal.code("9").size shouldBe 1 - cpg.literal.code("5").size shouldBe 1 - cpg.literal.code("0").size shouldBe 3 - cpg.literal.code("1").size shouldBe 3 - cpg.literal.code("10").size shouldBe 1 - cpg.literal.code("100").size shouldBe 2 - } - - "recognise all call nodes" in { - cpg.call.name("puts").size shouldBe 1 - } - } - - "Next statements used as a conditional return for literals" should { - val cpg = code(""" - |grouped_currencies = Money::Currency.all.group_by do |currency| - | next "Major" if MAJOR_CURRENCY_CODES.include?(currency.iso_code) - | "Exotic" - |end - |""".stripMargin) - - "convert the CONTINUE to a RETURN" in { - cpg.controlStructure.controlStructureType(ControlStructureTypes.CONTINUE).size shouldBe 0 - cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).size shouldBe 1 - } - - "return `Major` under the if-statement but return `Exotic` otherwise" in { - val List(ifStmt) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l: @unchecked - val List(ifReturn) = ifStmt.astChildren.isReturn.l: @unchecked - val List(majorLiteral) = ifReturn.astChildren.isLiteral.l: @unchecked - majorLiteral.code shouldBe "\"Major\"" - val List(blockReturn) = ifStmt.astSiblings.isReturn.l: @unchecked - val List(exoticLiteral) = blockReturn.astChildren.isLiteral.l: @unchecked - exoticLiteral.code shouldBe "\"Exotic\"" - } - } - - "Next statements used as a conditional continue for calls" should { - val cpg = code(""" - |for i in 1..10 - | next if i % 2 == 0 - | puts i - |end - |""".stripMargin) - - "retain the CONTINUE under the `next` with no return value" in { - val List(cont: ControlStructure) = - cpg.controlStructure.controlStructureType(ControlStructureTypes.CONTINUE).l: @unchecked - val ifStmt = cont.astParent.asInstanceOf[ControlStructure]: @unchecked - ifStmt.controlStructureType shouldBe ControlStructureTypes.IF - } - - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/FieldAccessTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/FieldAccessTests.scala deleted file mode 100644 index c348f6e07edb..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/FieldAccessTests.scala +++ /dev/null @@ -1,50 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.querying - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language.* - -class FieldAccessTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "Test class field access" should { - val cpg = code(""" - |class Person - | attr_reader :name, :age - | - | def initialize(name, age) - | @name = name - | @age = age - | end - |end - | - |p = Person.new("name", 66) - |p.age - |""".stripMargin) - - "be correct for field access" ignore { - cpg.call.name("age").size shouldBe 1 - val List(program) = cpg.method.nameExact(":program").l - val List(programBlock) = program.astChildren.isBlock.l - val List(call) = programBlock.astChildren.isCall.codeExact("p.age").l - call.astChildren.isFieldIdentifier.canonicalNameExact("age").size shouldBe 1 - call.astChildren.isIdentifier.nameExact("p").size shouldBe 1 - } - } - - "Test array access" should { - val cpg = code("result = persons[1].age") - - "be correct for filed access" ignore { - cpg.call.name(":program").l - val List(program) = cpg.method.nameExact(":program").l - val List(programBlock) = program.astChildren.isBlock.l - val List(call) = programBlock.astChildren.isCall.l - val List(rowsCall) = call.astChildren.isCall.l - rowsCall.astChildren.isFieldIdentifier.canonicalNameExact("age").size shouldBe 1 - - val List(rowsCallLeft) = rowsCall.astChildren.isCall.l - rowsCallLeft.astChildren.isLiteral.codeExact("1").size shouldBe 1 - rowsCallLeft.astChildren.isIdentifier.nameExact("persons").size shouldBe 1 - call.astChildren.isIdentifier.nameExact("result").size shouldBe 1 - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/FunctionTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/FunctionTests.scala deleted file mode 100644 index c4ee925b88b2..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/FunctionTests.scala +++ /dev/null @@ -1,214 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.querying - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.joern.x2cpg.Defines -import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language.* - -class FunctionTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "CPG for code with class methods, members and locals in methods" should { - val cpg = code(""" - |class Person - | attr_accessor :name, :age - | - | def initialize(name, age) - | @name = name - | @age = age - | end - | - | def greet - | puts "Hello, my name is #{@name} and I am #{@age} years old." - | end - | - | def have_birthday - | @age += 1 - | puts "Happy birthday! You are now #{@age} years old." - | end - |end - | - |p = Person.new - |p.greet - |""".stripMargin) - - "recognise all identifier nodes" in { - cpg.identifier.name("name").size shouldBe 1 - cpg.identifier.name("age").size shouldBe 1 - cpg.fieldAccess.fieldIdentifier.canonicalName("name").size shouldBe 2 - cpg.fieldAccess.fieldIdentifier.canonicalName("age").size shouldBe 4 - cpg.identifier.size shouldBe 13 // 4 identifier node is for `puts = typeDef(__builtin.puts)` 1 node for class Person = typeDef - } - - "recognize all call nodes" in { - cpg.call.name("greet").size shouldBe 1 - cpg.call.name("puts").size shouldBe 2 - } - - "recognize all method nodes" in { - // Initialize => - cpg.method.name("initialize").size shouldBe 0 - cpg.method.name(Defines.ConstructorMethodName).size shouldBe 1 - cpg.method.name("greet").size shouldBe 1 - cpg.method.name("have_birthday").size shouldBe 1 - } - } - - "CPG for code with square brackets as methods" should { - val cpg = code(""" - |class MyClass < MyBaseClass - | def initialize - | @my_hash = {} - | end - | - | def [](key) - | @my_hash[key.to_s] - | end - | - | def []=(key, value) - | @my_hash[key.to_s] = value - | end - |end - | - |my_object = MyClass.new - | - |""".stripMargin) - - "recognise all method nodes" in { - cpg.method.name("\\[]").size shouldBe 1 - cpg.method.name("\\[]=").size shouldBe 1 - cpg.method.name("initialize").size shouldBe 0 - cpg.method.name(Defines.ConstructorMethodName).size shouldBe 1 - } - - "recognize all call nodes" in { - cpg.call - .name(Operators.assignment) - .size shouldBe 4 // +1 identifier node for TypeRef's assignment - cpg.call.name("to_s").size shouldBe 2 - cpg.call.name(Defines.ConstructorMethodName).size shouldBe 1 - cpg.call.size shouldBe 12 // 1 identifier node for TypeRef's assignment - } - - "recognize all identifier nodes" in { - cpg.fieldAccess.fieldIdentifier.canonicalName("my_hash").size shouldBe 3 - cpg.identifier.name("key").size shouldBe 2 - cpg.identifier.name("value").size shouldBe 1 - cpg.identifier.name("my_object").size shouldBe 1 - /* - * FIXME - * def []=(key, value) gets parsed incorrectly with parser error "no viable alternative at input 'def []=(key, value)'" - * This needs a fix in the parser and update to this UT after the fix - * FIXME - * MyClass is identified as a variableIdentifier and so an identifier. This needs to be fixed - */ - } - } - - "CPG for code with modules" should { - val cpg = code(""" - |module Module1 - | def method1_1 - | end - | def method1_2 - | end - |end - | - |module Module2 - | def method2_1 - | end - | def method2_2 - | end - |end - |""".stripMargin) - - "recognise all method nodes defined in modules" in { - cpg.method.name("method1_1").l.size shouldBe 1 - cpg.method.name("method1_2").l.size shouldBe 1 - cpg.method.name("method2_1").l.size shouldBe 1 - cpg.method.name("method2_2").l.size shouldBe 1 - } - } - - "CPG for code with private/protected/public" should { - val cpg = code(""" - |class SomeClass - | private - | def method1 - | end - | - | protected - | def method2 - | end - | - | public - | def method3 - | end - |end - | - |""".stripMargin) - - "recognise all method nodes" in { - cpg.method - .name("method1") - .size shouldBe 1 - cpg.method - .name("method1") - .size shouldBe 1 - cpg.method - .name("method3") - .size shouldBe 1 - - } - } - - "CPG for code with multiple yields" should { - val cpg = code(""" - |def yield_with_arguments - | x = "something" - | y = "something_else" - | yield(x,y) - |end - | - |yield_with_arguments { |arg1, arg2| puts "Yield block 1 #{arg1} and #{arg2}" } - |yield_with_arguments { |arg1, arg2| puts "Yield block 2 #{arg2} and #{arg1}" } - | - |""".stripMargin) - - "recognise all method nodes" in { - cpg.method - .name("yield_with_arguments") - .size shouldBe 1 - cpg.method - .name("yield_with_arguments_yield") - .size shouldBe 2 - } - - "recognise all call nodes" in { - cpg.call - .name("yield_with_arguments_yield") - .size shouldBe 1 - - cpg.call - .name("puts") - .size shouldBe 2 - } - - "recognise all identifier nodes" in { - cpg.identifier - .name("arg1") - .size shouldBe 2 - - cpg.identifier - .name("arg2") - .size shouldBe 2 - - cpg.identifier - .name("x") - .size shouldBe 2 - - cpg.identifier - .name("y") - .size shouldBe 2 - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/IdentifierTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/IdentifierTests.scala deleted file mode 100644 index e4b41ec4c9f6..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/IdentifierTests.scala +++ /dev/null @@ -1,122 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.querying - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language.* - -class IdentifierTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "CPG for code with a function call, arguments and function called from function " should { - val cpg = code(""" - | - |def extrareturn() - | ret = 6 - | return ret - |end - | - |def add_three_numbers(num1, num2, num3) - | sum = num1 + num2 + num3 + extrareturn() - | return sum - |end - | - |a = 1 - |b = 2 - |c = 3 - | - |sumOfThree = add_three_numbers( a, b, c ) - |""".stripMargin) - - "recognise all identifier nodes" in { - cpg.identifier.name("a").size shouldBe 2 - cpg.identifier.name("b").size shouldBe 2 - cpg.identifier.name("c").size shouldBe 2 - cpg.identifier.name("sumOfThree").size shouldBe 1 - cpg.identifier.name("num1").size shouldBe 1 - cpg.identifier.name("num2").size shouldBe 1 - cpg.identifier.name("num3").size shouldBe 1 - cpg.identifier.name("sum").size shouldBe 2 - cpg.identifier.name("ret").size shouldBe 2 - cpg.identifier.size shouldBe 16 // 2 identifier node is for methodRef's assigment - } - - "identify a single call node" in { - cpg.call.name("add_three_numbers").size shouldBe 1 - } - } - - "CPG for code with expressions of various types" should { - val cpg = code(""" - |a = 1 - |b = 2 if a > 1 - |b = !a - |c = ~a - |e = +a - |f = b**a - |g = a*b - |h = a+b - |i = a >> b - |j = a | b - |k = a & b - |l = a && b - |m = a || b - |n = a .. b - |o = a ... b - |p = ( a > b ) ? c : e - |q = not p - |r = p and q - |s = p or q - |""".stripMargin) - - "recognise all identifier nodes" in { - cpg.identifier.name("a").size shouldBe 16 - cpg.identifier.name("b").size shouldBe 13 // unaryExpression - cpg.identifier.name("c").size shouldBe 2 // unaryExpression - cpg.identifier.name("e").size shouldBe 2 // unaryExpression - cpg.identifier.name("f").size shouldBe 1 // powerExpression - cpg.identifier.name("g").size shouldBe 1 // multiplicative Expression - cpg.identifier.name("h").size shouldBe 1 // additive Expression - cpg.identifier.name("i").size shouldBe 1 // bitwise shift Expression - cpg.identifier.name("j").size shouldBe 1 // bitwise or Expression - cpg.identifier.name("k").size shouldBe 1 // bitwise and Expression - cpg.identifier.name("l").size shouldBe 1 // operator and Expression - cpg.identifier.name("m").size shouldBe 1 // operator or Expression - cpg.identifier.name("n").size shouldBe 1 // inclusive range Expression - cpg.identifier.name("o").size shouldBe 1 // exclusive range Expression - cpg.identifier.name("p").size shouldBe 4 // conditionalOperatorExpression - cpg.identifier.name("q").size shouldBe 3 // notExpressionOrCommand - cpg.identifier.name("r").size shouldBe 1 // orAndExpressionOrCommand and part - cpg.identifier.name("s").size shouldBe 1 // orAndExpressionOrCommand or part - cpg.identifier.size shouldBe 52 - } - } - - "CPG for code with identifier and method name conflicts" should { - val cpg = code(""" - |def create_conflict(id) - | puts id - |end - | - |create_conflict = 123 - | - |puts create_conflict - |puts create_conflict + 1 - |puts create_conflict(1) - | - |""".stripMargin) - - "recognise all identifier nodes" in { - cpg.identifier - .name("create_conflict") - .size shouldBe 4 // 1 identifier node is for methodRef's assignment - } - - "recognise all call nodes" in { - cpg.call - .name("puts") - .size shouldBe 4 - - cpg.call - .name("create_conflict") - .size shouldBe 1 - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/MiscTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/MiscTests.scala deleted file mode 100644 index c40c6ddd2a2c..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/MiscTests.scala +++ /dev/null @@ -1,332 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.querying - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.joern.x2cpg.Defines -import io.shiftleft.semanticcpg.language.* - -class MiscTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) { - - "CPG for code with BEGIN and END blocks" should { - val cpg = code(""" - |#!/usr/bin/env ruby - | - |# This code block will be executed before the program begins - |BEGIN { - | beginvar = 5 - | beginbool = beginvar > 21 - |} - | - |# This is the main logic of the program - |puts "Hello, world!" - | - |# This code block will be executed after the program finishes - |END { - | endvar = 67 - | endbool = endvar > 23 - |} - |""".stripMargin) - - "recognise all identifier and call nodes" in { - cpg.identifier.name("beginvar").size shouldBe 2 - cpg.identifier.name("endvar").size shouldBe 2 - cpg.identifier.name("beginbool").size shouldBe 1 - cpg.identifier.name("endbool").size shouldBe 1 - cpg.call.name("puts").size shouldBe 1 - cpg.identifier.size shouldBe 7 // 1 identifier node is for `puts = typeDef(__builtin.puts)` - } - } - - "CPG for code with namespace resolution being used" should { - val cpg = code(""" - |Rails.application.configure do - | config.log_formatter = ::Logger::Formatter.new - |end - | - |""".stripMargin) - - "recognise all identifier and call nodes" in { - cpg.call.name("application").size shouldBe 1 - cpg.call.name("configure").size shouldBe 1 - cpg.call.name(Defines.ConstructorMethodName).size shouldBe 1 - cpg.call.name(".scopeResolution").size shouldBe 2 - cpg.identifier.name("Rails").size shouldBe 1 - cpg.identifier.name("config").size shouldBe 1 - cpg.identifier.name("Formatter").size shouldBe 1 - cpg.identifier.name("Logger").size shouldBe 1 - cpg.identifier.name("log_formatter").size shouldBe 1 - cpg.identifier.size shouldBe 5 - } - } - - "CPG for code with defined? keyword" should { - val cpg = code(""" - |radius = 2 - | - |area = 3.14 * radius * radius - | - |# Checking if the variable is defined or not - |# Using defined? keyword - |res1 = defined? radius - |res2 = defined? height - |res3 = defined? area - |res4 = defined? Math::PI - | - |# Displaying results - |puts "Result 1: #{res1}" - |puts "Result 2: #{res2}" - |puts "Result 3: #{res3}" - |puts "Result 4: #{res4}" - |""".stripMargin) - - "recognise all identifier nodes" in { - cpg.identifier.name("radius").size shouldBe 4 - cpg.identifier.name("area").size shouldBe 2 - cpg.identifier.name("height").size shouldBe 1 - cpg.identifier.name("res1").size shouldBe 2 - cpg.identifier.name("res2").size shouldBe 2 - cpg.identifier.name("res3").size shouldBe 2 - cpg.identifier.name("res4").size shouldBe 2 - cpg.identifier.name("Math").size shouldBe 1 - cpg.identifier.name("PI").size shouldBe 1 - } - - "recognise all literal nodes" in { - cpg.literal.code("3.14").size shouldBe 1 - cpg.literal.code("2").size shouldBe 1 - cpg.literal.code("Result 1: ").size shouldBe 1 - cpg.literal.code("Result 2: ").size shouldBe 1 - cpg.literal.code("Result 3: ").size shouldBe 1 - cpg.literal.code("Result 4: ").size shouldBe 1 - } - } - - "CPG for code with association statements" should { - val cpg = code(""" - |class Employee < EmployeeBase - | has_many :teams, foreign_key: "team_id", class_name: "Team" - | has_many :worklocations, foreign_key: "location_id", class_name: "WorkLocation" - |end - |""".stripMargin) - - "recognise all literal nodes" in { - cpg.literal - .code("\"team_id\"") - .size shouldBe 1 - cpg.literal - .code("\"location_id\"") - .size shouldBe 1 - } - - "recognise all activeRecordAssociation operator calls" in { - cpg.call - .name(".activeRecordAssociation") - .size shouldBe 4 - } - } - - "CPG for code with class having a scoped constant reference" should { - val cpg = code(""" - |class ModuleName::ClassName - | def some_method - | puts "Inside the method" - | end - |end - |""".stripMargin) - - "recognise all literal nodes" in { - cpg.literal - .code("\"Inside the method\"") - .size shouldBe 1 - } - - "recognise all method nodes" in { - cpg.method - .name("some_method") - .size shouldBe 1 - } - } - - "CPG for code with alias" should { - val cpg = code(""" - |def some_method(arg) - |puts arg - |end - |alias :alias_name :some_method - |alias_name("some param") - |""".stripMargin) - - "recognise all call nodes" in { - cpg.call - .name("puts") - .size shouldBe 1 - cpg.call - .name("some_method") - .size shouldBe 1 - } - } - - "CPG for code with rescue clause" should { - val cpg = code(""" - |begin - | puts "In begin" - |rescue SomeException - | puts "SomeException occurred" - |rescue => exceptionVar - | puts "Caught exception in variable #{exceptionVar}" - |rescue - | puts "Catch-all block" - |end - | - |""".stripMargin) - - "recognise all literal nodes" in { - cpg.literal - .code("\"In begin\"") - .size shouldBe 1 - cpg.literal - .code("\"SomeException occurred\"") - .size shouldBe 1 - cpg.literal - .code("\"Catch-all block\"") - .size shouldBe 1 - } - - "recognise all call nodes" in { - cpg.call - .name("puts") - .size shouldBe 4 - } - - "recognise all identifier nodes" in { - cpg.identifier - .name("exceptionVar") - .size shouldBe 2 - } - } - - "CPG for code with addition of method returns" should { - val cpg = code(""" - |def num1; 1; end - |def num2; 2; end - |def num3; 3; end - |x = num1 + num2 + num3 - |puts x - |""".stripMargin) - - "recognise all identifier nodes" in { - cpg.identifier - .name("x") - .size shouldBe 2 - } - - "recognise all call nodes" in { - cpg.call - .name("num1") - .size shouldBe 1 - - cpg.call - .name("num2") - .size shouldBe 1 - - cpg.call - .name("num3") - .size shouldBe 1 - } - } - - "CPG for code with chained constants as argument" should { - val cpg = code(""" - |SomeFramework.someMethod SomeModule::SomeSubModule::submoduleMethod do - |puts "nothing important" - |end - |""".stripMargin) - - "recognise all method nodes" ignore { - cpg.method.name("submoduleMethod2").size shouldBe 1 - } - - "recognise all call nodes" in { - cpg.call - .name("submoduleMethod") - .size shouldBe 1 - - cpg.call - .name("puts") - .size shouldBe 1 - - cpg.call - .name(".scopeResolution") - .size shouldBe 1 - } - } - - "CPG for code with singleton object of some class" ignore { - val cpg = code(""" - |class << "some_class" - |end - |""".stripMargin) - - "recognise all typedecl nodes" in { - cpg.typeDecl.name("some_class").size shouldBe 1 - } - } - - // TODO obj.foo="arg" should be interpreted as obj.foo("arg"). code change pending - "CPG for code with method ending with =" should { - val cpg = code(""" - |class MyClass - | def foo=(value) - | puts value - | end - |end - | - |obj = MyClass.new - |obj.foo="arg" - |""".stripMargin) - - "recognise all call nodes" in { - cpg.call.name("puts").size shouldBe 1 - cpg.call.name(".fieldAccess").size shouldBe 1 - } - - "recognise all method nodes" in { - cpg.method.name("foo=").size shouldBe 1 - } - } - - // expectation is that this should not cause a crash - "CPG for code with method having a singleton class" should { - val cpg = code(""" - |module SomeModule - | def self.someMethod(arg) - | class << arg - | end - | end - |end - |""".stripMargin) - - "recognise all namespace nodes" in { - cpg.namespace.name("SomeModule").size shouldBe 1 - } - } - - "CPG for code with super without arguments" should { - val cpg = code(""" - |class Parent - | def foo(arg) - | end - |end - | - |class Child < Parent - | def foo(arg) - | super - | end - |end - |""".stripMargin) - - "recognise all call nodes" in { - cpg.call.name(".super").size shouldBe 1 - cpg.method.name("foo").size shouldBe 2 - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/RubyMethodFullNameTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/RubyMethodFullNameTests.scala deleted file mode 100644 index 4ce813878f75..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/deprecated/querying/RubyMethodFullNameTests.scala +++ /dev/null @@ -1,88 +0,0 @@ -package io.joern.rubysrc2cpg.deprecated.querying - -import io.joern.rubysrc2cpg.Config -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language.* -import org.scalatest.BeforeAndAfterAll - -class RubyMethodFullNameTests extends RubyCode2CpgFixture(useDeprecatedFrontend = true) with BeforeAndAfterAll { - - private val config = Config().withDownloadDependencies(true) - - "Code for method full name when method present in module" should { - val cpg = code( - """ - |require "dummy_logger" - | - |v = Main_module::Main_outer_class.new - |v.first_fun("value") - | - |g = Help.new - |g.help_print() - | - |""".stripMargin, - "main.rb" - ) - .moreCode( - """ - |source 'https://rubygems.org' - |gem 'dummy_logger' - | - |""".stripMargin, - "Gemfile" - ) - .withConfig(config) - "recognise call node" in { - cpg.call.name("first_fun").l.size shouldBe 1 - } - - "recognise methodFullName for call Node" ignore { - if (!scala.util.Properties.isWin) { - cpg.call.name("first_fun").head.methodFullName should equal( - "dummy_logger::program:Main_module:Main_outer_class:first_fun" - ) - cpg.call - .name("help_print") - .head - .methodFullName shouldBe "dummy_logger::program:Help:help_print" - } - } - } - - "Code for method full name when method present in other file" should { - val cpg = code( - """ - |require_relative "util/help.rb" - | - |v = Outer.new - |v.printValue() - | - |""".stripMargin, - "main.rb" - ) - .moreCode( - """ - |class Outer - | def printValue() - | puts "print" - | end - |end - |""".stripMargin, - Seq("util", "help.rb").mkString(java.io.File.separator) - ) - .withConfig(config) - - "recognise call node" in { - cpg.call.name("printValue").size shouldBe 1 - } - - "recognise method full name for call node" ignore { - if (!scala.util.Properties.isWin) { - cpg.call - .name("printValue") - .head - .methodFullName shouldBe "util/help.rb::program:Outer:printValue" - } - } - } -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/io/RubySrc2CpgHTTPServerTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/io/RubySrc2CpgHTTPServerTests.scala new file mode 100644 index 000000000000..5538dd30b360 --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/io/RubySrc2CpgHTTPServerTests.scala @@ -0,0 +1,83 @@ +package io.joern.rubysrc2cpg.io + +import better.files.File +import io.joern.x2cpg.utils.server.FrontendHTTPClient +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader +import io.shiftleft.semanticcpg.language.* +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.collection.parallel.CollectionConverters.RangeIsParallelizable +import scala.util.Failure +import scala.util.Success + +class RubySrc2CpgHTTPServerTests extends AnyWordSpec with Matchers with BeforeAndAfterAll { + + private var port: Int = -1 + + private def newProjectUnderTest(index: Option[Int] = None): File = { + val dir = File.newTemporaryDirectory("rubysrc2cpgTestsHttpTest") + val file = dir / "main.rb" + file.createIfNotExists(createParents = true) + val indexStr = index.map(_.toString).getOrElse(""""Hello, World!"""") + file.writeText(s""" + |def main + | puts $indexStr + |end + |""".stripMargin) + file.deleteOnExit() + dir.deleteOnExit() + } + + override def beforeAll(): Unit = { + // Start server + port = io.joern.rubysrc2cpg.Main.startup() + } + + override def afterAll(): Unit = { + // Stop server + io.joern.rubysrc2cpg.Main.stop() + } + + "Using rubysrc2cpg in server mode" should { + "build CPGs correctly (single test)" in { + val cpgOutFile = File.newTemporaryFile("rubysrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l should contain("""puts "Hello, World!"""") + } + } + + "build CPGs correctly (multi-threaded test)" in { + (0 until 10).par.foreach { index => + val cpgOutFile = File.newTemporaryFile("rubysrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest(Some(index)) + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l should contain(s"puts $index") + } + } + } + } + +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ConfigFileCreationPassTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ConfigFileCreationPassTests.scala new file mode 100644 index 000000000000..7ac475dfd55d --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ConfigFileCreationPassTests.scala @@ -0,0 +1,60 @@ +package io.joern.rubysrc2cpg.passes + +import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.shiftleft.semanticcpg.language.* + +class ConfigFileCreationPassTests extends RubyCode2CpgFixture { + + "yaml files should be included" in { + val cpg = code( + """ + |foo: + | bar + |""".stripMargin, + "config.yaml" + ) + + val config = cpg.configFile.name("config.yaml").head + config.content should include("foo:") + } + + "yml files should be included" in { + val cpg = code( + """ + |foo: + | bar + |""".stripMargin, + "config.yml" + ) + + val config = cpg.configFile.name("config.yml").head + config.content should include("foo:") + } + + "xml files should be included" in { + val cpg = code( + """ + | + |

bar

+ | + |""".stripMargin, + "config.xml" + ) + + val config = cpg.configFile.name("config.xml").head + config.content should include("

bar

") + } + + "erb files should be included" in { + val cpg = code( + """ + |<%= 1 + 2 %> + |""".stripMargin, + "foo.erb" + ) + + val config = cpg.configFile.name("foo.erb").head + config.content should include("1 + 2") + } + +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryTests.scala index bb06adf658b3..df13ed374451 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryTests.scala @@ -1,11 +1,11 @@ package io.joern.rubysrc2cpg.passes +import io.joern.rubysrc2cpg.passes.Defines.Main +import io.joern.rubysrc2cpg.passes.GlobalTypes.{corePrefix, kernelPrefix} import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.joern.x2cpg.Defines as XDefines -import io.joern.rubysrc2cpg.passes.GlobalTypes.kernelPrefix import io.shiftleft.codepropertygraph.generated.nodes.Identifier -import io.shiftleft.semanticcpg.language.importresolver.* import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.language.importresolver.* import scala.collection.immutable.List @@ -59,14 +59,14 @@ class RubyInternalTypeRecoveryTests extends RubyCode2CpgFixture(withPostProcessi "resolve 'print' and 'puts' StubbedRubyType calls" in { val List(printCall) = cpg.call("print").l - printCall.methodFullName shouldBe s"$kernelPrefix:print" + printCall.methodFullName shouldBe Defines.prefixAsKernelDefined("print") val List(maxCall) = cpg.call("puts").l - maxCall.methodFullName shouldBe s"$kernelPrefix:puts" + maxCall.methodFullName shouldBe Defines.prefixAsKernelDefined("puts") } "present the declared method name when a built-in with the same name is used in the same compilation unit" in { val List(absCall) = cpg.call("sleep").l - absCall.methodFullName shouldBe "main.rb:::program:sleep" + absCall.methodFullName shouldBe s"main.rb:$Main.sleep" } } @@ -88,8 +88,8 @@ class RubyInternalTypeRecoveryTests extends RubyCode2CpgFixture(withPostProcessi "propagate function return types" in { inside(cpg.method.name("func2?").l) { case func :: func2 :: Nil => - func.methodReturn.typeFullName shouldBe s"$kernelPrefix.String" - func2.methodReturn.typeFullName shouldBe s"$kernelPrefix.String" + func.methodReturn.typeFullName shouldBe Defines.prefixAsCoreType("String") + func2.methodReturn.typeFullName shouldBe Defines.prefixAsCoreType("String") case xs => fail(s"Expected 2 functions, got [${xs.name.mkString(",")}]") } } @@ -97,7 +97,7 @@ class RubyInternalTypeRecoveryTests extends RubyCode2CpgFixture(withPostProcessi "propagate return type to identifier c" in { inside(cpg.identifier.name("c").l) { case cIdent :: Nil => - cIdent.typeFullName shouldBe s"$kernelPrefix.String" + cIdent.typeFullName shouldBe Defines.prefixAsCoreType("String") case xs => fail(s"Expected one identifier for c, got [${xs.name.mkString(",")}]") } } @@ -135,13 +135,13 @@ class RubyInternalTypeRecoveryTests extends RubyCode2CpgFixture(withPostProcessi case funcAssignment :: constructAssignment :: tmpAssignment :: Nil => inside(funcAssignment.argument.l) { case (lhs: Identifier) :: rhs :: Nil => - lhs.typeFullName shouldBe s"$kernelPrefix.String" + lhs.typeFullName shouldBe Defines.prefixAsCoreType("String") case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}] ") } inside(constructAssignment.argument.l) { case (lhs: Identifier) :: rhs :: Nil => - lhs.typeFullName shouldBe "test2.rb:::program.Test2A" + lhs.typeFullName shouldBe s"test2.rb:$Main.Test2A" case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]") } case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]") @@ -167,12 +167,13 @@ class RubyInternalTypeRecoveryTests extends RubyCode2CpgFixture(withPostProcessi "propagate to identifier" ignore { inside(cpg.identifier.name("(a|b)").l) { case aIdent :: bIdent :: Nil => - aIdent.typeFullName shouldBe "Test0.rb:::program.A" - bIdent.typeFullName shouldBe "Test0.rb:::program.A" + aIdent.typeFullName shouldBe s"Test0.rb:$Main.A" + bIdent.typeFullName shouldBe s"Test0.rb:$Main.A" case xs => fail(s"Expected one identifier, got [${xs.name.mkString(",")}]") } } } + } class RubyExternalTypeRecoveryTests @@ -195,7 +196,7 @@ class RubyExternalTypeRecoveryTests // TODO: Revisit "be present in (Case 1)" ignore { cpg.identifier("sg").lineNumber(5).typeFullName.l shouldBe List("sendgrid-ruby.SendGrid.API") - cpg.call("client").methodFullName.headOption shouldBe Option("sendgrid-ruby.SendGrid.API:client") + cpg.call("client").methodFullName.headOption shouldBe Option("sendgrid-ruby.SendGrid.API.client") } "resolve correct imports via tag nodes" in { @@ -219,7 +220,7 @@ class RubyExternalTypeRecoveryTests "be present in (Case 2)" ignore { cpg.call("post").methodFullName.l shouldBe List( - "sendgrid-ruby::program.SendGrid.API.client.mail.anonymous.post" + s"sendgrid-ruby.$Main.SendGrid.API.client.mail.anonymous.post" ) } } @@ -255,9 +256,9 @@ class RubyExternalTypeRecoveryTests "resolve 'x' and 'y' locally under foo.rb" in { val Some(x) = cpg.identifier("x").where(_.file.name(".*foo.*")).headOption: @unchecked - x.typeFullName shouldBe s"$kernelPrefix.Integer" + x.typeFullName shouldBe Defines.prefixAsCoreType("Integer") val Some(y) = cpg.identifier("y").where(_.file.name(".*foo.*")).headOption: @unchecked - y.typeFullName shouldBe s"$kernelPrefix.String" + y.typeFullName shouldBe Defines.prefixAsCoreType("String") } "resolve 'FooModule.x' and 'FooModule.y' field access primitive types correctly" in { @@ -268,9 +269,9 @@ class RubyExternalTypeRecoveryTests .name("z") .l z1.typeFullName shouldBe "ANY" - z1.dynamicTypeHintFullName shouldBe Seq(s"$kernelPrefix.Integer", s"$kernelPrefix.String") + z1.dynamicTypeHintFullName shouldBe Seq(Defines.prefixAsCoreType("Integer"), Defines.prefixAsCoreType("String")) z2.typeFullName shouldBe "ANY" - z2.dynamicTypeHintFullName shouldBe Seq(s"$kernelPrefix.Integer", s"$kernelPrefix.String") + z2.dynamicTypeHintFullName shouldBe Seq(Defines.prefixAsCoreType("Integer"), Defines.prefixAsCoreType("String")) } "resolve 'FooModule.d' field access object types correctly" ignore { @@ -280,7 +281,7 @@ class RubyExternalTypeRecoveryTests .isIdentifier .name("d") .headOption: @unchecked - d.typeFullName shouldBe "dbi::program.DBI.connect." + d.typeFullName shouldBe "dbi.$Main.DBI.connect." d.dynamicTypeHintFullName shouldBe Seq() } @@ -291,7 +292,7 @@ class RubyExternalTypeRecoveryTests .isCall .name("select_one") .l - d.methodFullName shouldBe "dbi::program.DBI.connect..select_one" + d.methodFullName shouldBe "dbi.$Main.DBI.connect..select_one" d.dynamicTypeHintFullName shouldBe Seq() d.callee(NoResolve).isExternal.headOption shouldBe Some(true) } @@ -300,10 +301,10 @@ class RubyExternalTypeRecoveryTests "resolve correct imports via tag nodes" ignore { val List(foo: ResolvedTypeDecl) = cpg.file(".*foo.rb").ast.isCall.where(_.referencedImports).tag._toEvaluatedImport.toList: @unchecked - foo.fullName shouldBe "dbi::program.DBI" + foo.fullName shouldBe s"dbi.$Main.DBI" val List(bar: ResolvedTypeDecl) = cpg.file(".*bar.rb").ast.isCall.where(_.referencedImports).tag._toEvaluatedImport.toList: @unchecked - bar.fullName shouldBe "foo.rb::program.FooModule" + bar.fullName shouldBe s"foo.rb.$Main.FooModule" } } @@ -341,7 +342,7 @@ class RubyExternalTypeRecoveryTests val Some(log) = cpg.identifier("log").headOption: @unchecked log.typeFullName shouldBe "logger.Logger" val List(errorCall) = cpg.call("error").l - errorCall.methodFullName shouldBe "logger.Logger:error" + errorCall.methodFullName shouldBe "logger.Logger.error" } } @@ -360,12 +361,12 @@ class RubyExternalTypeRecoveryTests "resolved the type of call" in { val Some(create) = cpg.call("create").headOption: @unchecked - create.methodFullName shouldBe "stripe.rb:::program.Stripe.Customer:create" + create.methodFullName shouldBe s"stripe.rb:$Main.Stripe.Customer.create" } "resolved the type of identifier" in { val Some(customer) = cpg.identifier("customer").headOption: @unchecked - customer.typeFullName shouldBe "stripe::program.Stripe.Customer.create." + customer.typeFullName shouldBe s"stripe.$Main.Stripe.Customer.create." } } @@ -384,11 +385,11 @@ class RubyExternalTypeRecoveryTests .moreCode(RubyExternalTypeRecoveryTests.LOGGER_GEMFILE, "Gemfile") "have a correct type for call `connect`" in { - cpg.call("error").methodFullName.l shouldBe List("logger.Logger:error") + cpg.call("error").methodFullName.l shouldBe List("logger.Logger.error") } "have a correct type for identifier `d`" in { - cpg.identifier("e").typeFullName.l shouldBe List("logger.Logger:error.") + cpg.identifier("e").typeFullName.l shouldBe List("logger.Logger.error.") } } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/AccessModifierTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/AccessModifierTests.scala new file mode 100644 index 000000000000..f785b9b07df2 --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/AccessModifierTests.scala @@ -0,0 +1,107 @@ +package io.joern.rubysrc2cpg.querying + +import io.joern.rubysrc2cpg.passes.Defines +import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.nodes.Call +import io.shiftleft.semanticcpg.language.* + +class AccessModifierTests extends RubyCode2CpgFixture { + + "methods defined on the

level are private" in { + val cpg = code(""" + |def foo + |end + |""".stripMargin) + + cpg.method("foo").head.isPrivate.size shouldBe 1 + } + + "a method should be public by default, with the `initialize` default constructor private" in { + val cpg = code(""" + |class Foo + | def bar + | end + |end + |""".stripMargin) + + cpg.method("bar").head.isPublic.size shouldBe 1 + cpg.method(Defines.Initialize).head.isPrivate.size shouldBe 1 + cpg.method(Defines.TypeDeclBody).head.isPrivate.size shouldBe 1 + } + + "an access modifier should affect the visibility of subsequent method definitions" in { + val cpg = code(""" + |class Foo + | def bar + | end + | + | private + | + | def baz + | end + | + | def faz + | end + | + |end + | + |class Baz + | def test1 + | end + | + | protected + | + | def test2 + | end + |end + |""".stripMargin) + + cpg.method("bar").head.isPublic.size shouldBe 1 + + cpg.method("baz").head.isPrivate.size shouldBe 1 + cpg.method("faz").head.isPrivate.size shouldBe 1 + + cpg.method("test1").head.isPublic.size shouldBe 1 + cpg.method("test2").head.isProtected.size shouldBe 1 + } + + "nested types should 'remember' their access modifier mode according to scope" in { + val cpg = code(""" + |class Foo + | private + | + | class Bar + | + | public + | def baz + | end + | + | end + | + | def test + | end + |end + |""".stripMargin) + + cpg.method("baz").isPublic.size shouldBe 1 + cpg.method("test").isPrivate.size shouldBe 1 + } + + "an identifier sharing the same name as an access modifier in an unambiguous spot should not be confused" in { + val cpg = code(""" + | def message_params + | { + | private: @private + | } + | end + |""".stripMargin) + + val privateKey = cpg.literal(":private").head + val indexAccess = privateKey.astParent.asInstanceOf[Call] + indexAccess.name shouldBe Operators.indexAccess + indexAccess.methodFullName shouldBe Operators.indexAccess + indexAccess.code shouldBe "[:private]" + } + +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ArrayTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ArrayTests.scala index 86d9008a8a13..752348c50bf2 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ArrayTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ArrayTests.scala @@ -1,101 +1,169 @@ package io.joern.rubysrc2cpg.querying -import io.joern.rubysrc2cpg.passes.GlobalTypes.{builtinPrefix, kernelPrefix} +import io.joern.rubysrc2cpg.passes.Defines +import io.joern.rubysrc2cpg.passes.Defines.RubyOperators +import io.joern.rubysrc2cpg.passes.GlobalTypes.corePrefix import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.Operators + +import io.joern.x2cpg.Defines as XDefines +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment class ArrayTests extends RubyCode2CpgFixture { - "`[]` is represented by an `arrayInitializer` operator call" in { + "`[]` is represented by an alloc constructor call" in { val cpg = code(""" |[] |""".stripMargin) - val List(arrayCall) = cpg.call.l + val List(arrayCall) = cpg.call.name("initialize").l - arrayCall.methodFullName shouldBe Operators.arrayInitializer + arrayCall.methodFullName shouldBe XDefines.DynamicCallUnknownFullName arrayCall.code shouldBe "[]" arrayCall.lineNumber shouldBe Some(2) + inside(arrayCall.argument.l) { case callBase :: Nil => + callBase.code shouldBe "" + } + + inside(arrayCall.astSiblings.l) { case (_: Local) :: (asgnCall: Call) :: (_: Identifier) :: Nil => + val asgn = asgnCall.asInstanceOf[Assignment] + val rhs = asgn.source.asInstanceOf[Call] + rhs.methodFullName shouldBe Operators.alloc + rhs.code shouldBe "[]" + } } - "`[1]` is represented by an `arrayInitializer` operator call with arguments `1`" in { + "`[1]` is represented by an alloc call with arguments `1`" in { val cpg = code(""" |[1] |""".stripMargin) - val List(arrayCall) = cpg.call.l + val List(arrayCall) = cpg.call.name("initialize").l - arrayCall.methodFullName shouldBe Operators.arrayInitializer + arrayCall.methodFullName shouldBe XDefines.DynamicCallUnknownFullName arrayCall.code shouldBe "[1]" arrayCall.lineNumber shouldBe Some(2) + inside(arrayCall.argument.l) { case callBase :: Nil => + callBase.code shouldBe "" + } - val List(one) = arrayCall.argument.l + inside(arrayCall.parentBlock.parentBlock.lastOption) { case Some(loweringBlock: Block) => + val asgn = loweringBlock.astChildren.assignment.last + val lhs = asgn.target.asInstanceOf[Call] + lhs.methodFullName shouldBe Operators.indexAccess + lhs.code shouldBe "[0]" - one.code shouldBe "1" + val rhs = asgn.source.asInstanceOf[Literal] + rhs.code shouldBe "1" + } } - "`[1,2,]` is represented by an `arrayInitializer` operator call with arguments `1`, `2`" in { + "`[1,2,]` is represented by an alloc call with arguments `1`, `2`" in { val cpg = code(""" |[1,2,] |""".stripMargin) - val List(arrayCall) = cpg.call.l - - arrayCall.methodFullName shouldBe Operators.arrayInitializer - arrayCall.code shouldBe "[1,2,]" - arrayCall.lineNumber shouldBe Some(2) - - val List(one, two) = arrayCall.argument.l + inside(cpg.block.codeExact("[1,2,]").astChildren.assignment.where(_.source.isLiteral).l) { + case asgn1 :: asgn2 :: Nil => + asgn1.code shouldBe "[0] = 1" + asgn1.source.asInstanceOf[Literal].code shouldBe "1" - one.code shouldBe "1" - two.code shouldBe "2" + asgn2.code shouldBe "[1] = 2" + asgn2.source.asInstanceOf[Literal].code shouldBe "2" + } } - "`%w{}` is represented by an `arrayInitializer` operator call" in { + "`%w{}` is represented by an alloc call" in { val cpg = code(""" |%w{} |""".stripMargin) - val List(arrayCall) = cpg.call.l + val List(arrayCall) = cpg.block.codeExact("%w{}").l - arrayCall.methodFullName shouldBe Operators.arrayInitializer arrayCall.code shouldBe "%w{}" arrayCall.lineNumber shouldBe Some(2) - arrayCall.argument.isEmpty shouldBe true + arrayCall.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).isEmpty shouldBe true } - "`%w?foo?` is represented by an `arrayInitializer` operator call with arguments 'foo'" in { + "`%w?foo?` is represented by an alloc call with arguments 'foo'" in { val cpg = code(""" |%w?foo? |""".stripMargin) - val List(arrayCall) = cpg.call.name(Operators.arrayInitializer).l + val List(arrayCall, _) = cpg.block.codeExact("%w?foo?").l arrayCall.code shouldBe "%w?foo?" arrayCall.lineNumber shouldBe Some(2) + val asgns = arrayCall.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l - val List(foo) = arrayCall.argument.isLiteral.l - foo.code shouldBe "foo" - foo.typeFullName shouldBe s"$kernelPrefix.String" + inside(asgns.map(_.source)) { case (foo: Literal) :: Nil => + foo.code shouldBe "foo" + foo.typeFullName shouldBe Defines.prefixAsCoreType("String") + } } - "`%i(x y)` is represented by an `arrayInitializer` operator call with arguments `:x`, `:y`" in { + "`%i(x y)` is represented by an alloc call with arguments `:x`, `:y`" in { val cpg = code(""" |%i(x y) |""".stripMargin) - val List(arrayCall) = cpg.call.name(Operators.arrayInitializer).l + val List(arrayCall, _) = cpg.block.codeExact("%i(x y)").l arrayCall.code shouldBe "%i(x y)" arrayCall.lineNumber shouldBe Some(2) + val asgns = arrayCall.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l - val List(x, y) = arrayCall.argument.isLiteral.l - x.code shouldBe "x" - x.typeFullName shouldBe y.typeFullName + inside(asgns.map(_.source)) { case (x: Literal) :: (y: Literal) :: Nil => + x.code shouldBe ":x" + x.typeFullName shouldBe Defines.prefixAsCoreType("Symbol") - y.code shouldBe "y" - y.typeFullName shouldBe s"$kernelPrefix.Symbol" + y.code shouldBe ":y" + y.typeFullName shouldBe Defines.prefixAsCoreType("Symbol") + } + } + + "%W is represented an alloc call" in { + val cpg = code("""%W(x#{1 + 3} y#{23} z) + |""".stripMargin) + + val List(arrayCall, _) = cpg.block.codeExact("%W(x#{1 + 3} y#{23} z)").l + + arrayCall.code shouldBe "%W(x#{1 + 3} y#{23} z)" + arrayCall.lineNumber shouldBe Some(1) + val asgns = arrayCall.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l + + inside(asgns.map(_.source)) { case (xFmt: Call) :: (yFmt: Call) :: (zLit: Literal) :: Nil => + xFmt.name shouldBe Operators.formatString + xFmt.typeFullName shouldBe Defines.prefixAsCoreType(Defines.String) + + yFmt.name shouldBe Operators.formatString + yFmt.typeFullName shouldBe Defines.prefixAsCoreType(Defines.String) + + val List(xFmtStr, xAddFmtStr) = xFmt.astChildren.isCall.l + xFmtStr.name shouldBe Operators.formattedValue + xAddFmtStr.name shouldBe Operators.formattedValue + + val List(xFmtStrAdd) = xAddFmtStr.astChildren.isCall.l + xFmtStrAdd.name shouldBe Operators.addition + + val List(lhs, rhs) = xFmtStrAdd.argument.l + lhs.code shouldBe "1" + rhs.code shouldBe "3" + + val List(yFmtStr, yFmt23) = yFmt.astChildren.isCall.l + yFmtStr.name shouldBe Operators.formattedValue + yFmt23.name shouldBe Operators.formattedValue + + val List(yFmtStrLit: Literal) = yFmt23.argument.l: @unchecked + yFmtStrLit.code shouldBe "23" + + zLit.code shouldBe "z" + zLit.typeFullName shouldBe Defines.prefixAsCoreType("String") + } } "an implicit array constructor (Array::[]) should be lowered to an array initializer" in { @@ -106,8 +174,8 @@ class ArrayTests extends RubyCode2CpgFixture { inside(cpg.call.nameExact("[]").l) { case bracketCall :: Nil => bracketCall.name shouldBe "[]" - bracketCall.methodFullName shouldBe s"$builtinPrefix.Array.[]" - bracketCall.typeFullName shouldBe s"$builtinPrefix.Array" + bracketCall.methodFullName shouldBe s"${Defines.prefixAsCoreType("Array")}.[]" + bracketCall.typeFullName shouldBe Defines.prefixAsCoreType("Array") inside(bracketCall.argument.l) { case _ :: one :: two :: three :: Nil => @@ -121,4 +189,117 @@ class ArrayTests extends RubyCode2CpgFixture { } + "%I array" in { + val cpg = code("%I(test_#{1} test_2)") + + val List(arrayCall, _) = cpg.block.codeExact("%I(test_#{1} test_2)").l + + arrayCall.code shouldBe "%I(test_#{1} test_2)" + arrayCall.lineNumber shouldBe Some(1) + val asgns = arrayCall.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l + + inside(asgns.map(_.source)) { case (test1Fmt: Call) :: (test2: Literal) :: Nil => + test1Fmt.name shouldBe Operators.formatString + test1Fmt.typeFullName shouldBe Defines.prefixAsCoreType(Defines.Symbol) + test1Fmt.code shouldBe "test_#{1}" + + val List(test1FmtLit, test1FmtSymbol) = test1Fmt.astChildren.isCall.l + test1FmtSymbol.name shouldBe Operators.formattedValue + test1FmtSymbol.typeFullName shouldBe Defines.prefixAsCoreType(Defines.Symbol) + test1FmtSymbol.code shouldBe "#{1}" + + test1FmtLit.name shouldBe Operators.formattedValue + + val List(test1FmtFinal: Literal) = test1FmtLit.argument.l: @unchecked + test1FmtFinal.code shouldBe "test_" + + test2.code shouldBe ":test_2" + test2.typeFullName shouldBe Defines.prefixAsCoreType(Defines.Symbol) + } + } + + "shift-left operator interpreted as a call (append)" in { + val cpg = code("[1, 2, 3] << 4") + + inside(cpg.call("<<").headOption) { + case Some(append) => + append.name shouldBe "<<" + append.methodFullName shouldBe XDefines.DynamicCallUnknownFullName + append.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + append.argument(0).code shouldBe "[1, 2, 3]" + append.argument(1).code shouldBe "4" + case None => fail(s"Expected call `<<`") + } + } + + "Array bodies with mixed elements" in { + val cpg = code("[1, 2 => 1, 2 => 3]") + + val List(arrayCall, _) = cpg.block.codeExact("[1, 2 => 1, 2 => 3]").l + + arrayCall.code shouldBe "[1, 2 => 1, 2 => 3]" + arrayCall.lineNumber shouldBe Some(1) + val asgns = arrayCall.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l + + inside(asgns.map(_.source)) { case (argLit: Literal) :: (argAssoc: Call) :: (argAssoc2: Call) :: Nil => + argLit.code shouldBe "1" + + argAssoc.code shouldBe "2 => 1" + argAssoc.methodFullName shouldBe Defines.RubyOperators.association + + argAssoc2.code shouldBe "2 => 3" + argAssoc2.methodFullName shouldBe Defines.RubyOperators.association + } + } + + "Array with mixed elements" in { + val cpg = code(""" + |[ + | *::ApplicationSettingsHelper.visible_attributes, + | { default_branch_protection_defaults: [ + | :allow_force_push, + | :developer_can_initial_push, + | { + | allowed_to_merge: [:access_level], + | allowed_to_push: [:access_level] + | } + | ] }, + | :can_create_organization, + | *::ApplicationSettingsHelper.some_other_attributes, + |] + |""".stripMargin) + + val List(arrayCall, _) = cpg.block + .codeExact("""[ + | *::ApplicationSettingsHelper.visible_attributes, + | { d...""".stripMargin) + .l + + arrayCall.code shouldBe + """[ + | *::ApplicationSettingsHelper.visible_attributes, + | { d...""".stripMargin + arrayCall.lineNumber shouldBe Some(2) + val asgns = arrayCall.astChildren + .collect { case x: Call if x.name == Operators.assignment => x.asInstanceOf[Assignment] } + .where(_.target.isCall.nameExact(Operators.indexAccess)) + .l + + inside(asgns.map(_.source)) { + case (splatArgOne: Call) :: (hashLiteralArg: Block) :: (symbolArg: Literal) :: (splatArgTwo: Call) :: Nil => + splatArgOne.methodFullName shouldBe RubyOperators.splat + splatArgOne.code shouldBe "*::ApplicationSettingsHelper.visible_attributes" + + symbolArg.code shouldBe ":can_create_organization" + symbolArg.typeFullName shouldBe Defines.prefixAsCoreType(Defines.Symbol) + + splatArgTwo.methodFullName shouldBe RubyOperators.splat + splatArgTwo.code shouldBe "*::ApplicationSettingsHelper.some_other_attributes" + + val List(hashInitAssignment: Call, _) = + hashLiteralArg.astChildren.isCall.name(Operators.assignment).l: @unchecked + val List(_: Identifier, hashInitCall: Call) = hashInitAssignment.argument.l: @unchecked + hashInitCall.methodFullName shouldBe RubyOperators.hashInitializer + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/AttributeAccessorTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/AttributeAccessorTests.scala new file mode 100644 index 000000000000..348bf22ed279 --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/AttributeAccessorTests.scala @@ -0,0 +1,63 @@ +package io.joern.rubysrc2cpg.querying + +import io.joern.x2cpg.Defines +import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.{Operators, DispatchTypes} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier} +import io.shiftleft.semanticcpg.language.* + +class AttributeAccessorTests extends RubyCode2CpgFixture { + + "`x.y=1` is approximated by a `x.y =` assignment with argument `1`" in { + val cpg = code("""x = Foo.new + |x.y = 1 + |""".stripMargin) + + inside(cpg.assignment.where(_.source.isLiteral.codeExact("1")).l) { + case xyAssign :: Nil => + xyAssign.lineNumber shouldBe Some(2) + xyAssign.code shouldBe "x.y = 1" + + val fieldTarget = xyAssign.target.asInstanceOf[Call] + fieldTarget.code shouldBe "x.y" + fieldTarget.name shouldBe Operators.fieldAccess + fieldTarget.methodFullName shouldBe Operators.fieldAccess + fieldTarget.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + + inside(fieldTarget.argument.l) { + case (base: Identifier) :: (field: FieldIdentifier) :: Nil => + base.name shouldBe "x" + field.canonicalName shouldBe "@y" + field.code shouldBe "y" + case xs => fail("Expected field access to have two targets") + } + case xs => fail("Expected a single assignment to the literal `1`") + } + } + + "`x.y` is represented by a field access `x.y`" in { + val cpg = code("""x = Foo.new + |a = x.y + |b = x.z() + |""".stripMargin) + // Test the field access + inside(cpg.fieldAccess.lineNumber(2).codeExact("x.y").l) { + case xyCall :: Nil => + xyCall.lineNumber shouldBe Some(2) + xyCall.code shouldBe "x.y" + xyCall.methodFullName shouldBe Operators.fieldAccess + xyCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + case xs => fail("Expected a single field access for `x.y`") + } + // Test an explicit call with parenthesis + inside(cpg.call("z").lineNumber(3).l) { + case xzCall :: Nil => + xzCall.lineNumber shouldBe Some(3) + xzCall.code shouldBe "x.z()" + xzCall.methodFullName shouldBe Defines.DynamicCallUnknownFullName + xzCall.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + case xs => fail("Expected a single call for `x.z()`") + } + } + +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala index 7286d66ef50a..7c8a35653ef1 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala @@ -1,12 +1,12 @@ package io.joern.rubysrc2cpg.querying -import io.joern.rubysrc2cpg.passes.{GlobalTypes, Defines as RubyDefines} -import io.joern.rubysrc2cpg.passes.Defines.RubyOperators +import io.joern.rubysrc2cpg.passes.Defines.{Main, RubyOperators} import io.joern.rubysrc2cpg.passes.GlobalTypes.kernelPrefix +import io.joern.rubysrc2cpg.passes.{GlobalTypes, Defines as RubyDefines} import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, NodeTypes, Operators} import io.shiftleft.semanticcpg.language.* class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { @@ -19,7 +19,7 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { val List(puts) = cpg.call.name("puts").l puts.lineNumber shouldBe Some(2) puts.code shouldBe "puts 'hello'" - puts.methodFullName shouldBe s"$kernelPrefix:puts" + puts.methodFullName shouldBe s"$kernelPrefix.puts" puts.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH val List(selfReceiver: Identifier, hello: Literal) = puts.argument.l: @unchecked @@ -53,7 +53,7 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { val List(puts) = cpg.call.name("puts").l puts.lineNumber shouldBe Some(2) puts.code shouldBe "Kernel.puts 'hello'" - puts.methodFullName shouldBe s"$kernelPrefix:puts" + puts.methodFullName shouldBe s"$kernelPrefix.puts" puts.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH val List(kernelRec: Call) = puts.receiver.l: @unchecked @@ -71,7 +71,7 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { val List(atan2) = cpg.call.name("atan2").l atan2.lineNumber shouldBe Some(3) atan2.code shouldBe "Math.atan2(1, 1)" - atan2.methodFullName shouldBe s"${GlobalTypes.builtinPrefix}.Math:atan2" + atan2.methodFullName shouldBe s"${RubyDefines.prefixAsCoreType("Math")}.atan2" atan2.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH val List(mathRec: Call) = atan2.receiver.l: @unchecked @@ -79,7 +79,7 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { mathRec.typeFullName shouldBe Defines.Any mathRec.code shouldBe s"Math.atan2" - mathRec.argument(1).asInstanceOf[TypeRef].typeFullName shouldBe s"${GlobalTypes.builtinPrefix}.Math" + mathRec.argument(1).asInstanceOf[TypeRef].typeFullName shouldBe RubyDefines.prefixAsCoreType("Math") mathRec.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "atan2" } @@ -155,35 +155,38 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { "a simple object instantiation" should { val cpg = code("""class A + | def initialize(a, b) + | end |end | - |a = A.new + |a = A.new 1, 2 |""".stripMargin) - "create an assignment from `a` to an invocation block" in { - inside(cpg.method(":program").assignment.where(_.target.isIdentifier.name("a")).l) { + "create an assignment from `a` to an alloc lowering invocation block" in { + inside(cpg.method.isModule.assignment.and(_.target.isIdentifier.name("a"), _.source.isBlock).l) { case assignment :: Nil => - assignment.code shouldBe "a = A.new" + assignment.code shouldBe "a = A.new 1, 2" inside(assignment.argument.l) { case (a: Identifier) :: (_: Block) :: Nil => a.name shouldBe "a" - a.dynamicTypeHintFullName should contain("Test0.rb:::program.A") + a.dynamicTypeHintFullName should contain(s"Test0.rb:$Main.A") case xs => fail(s"Expected one identifier and one call argument, got [${xs.code.mkString(",")}]") } case xs => fail(s"Expected a single assignment, got [${xs.code.mkString(",")}]") } } - "create an assignment from a temp variable to the call" in { - inside(cpg.method(":program").assignment.where(_.target.isIdentifier.name("")).l) { + "create an assignment from a temp variable to the alloc call" in { + inside(cpg.method.isModule.assignment.where(_.target.isIdentifier.name("")).l) { case assignment :: Nil => inside(assignment.argument.l) { case (a: Identifier) :: (alloc: Call) :: Nil => - a.name shouldBe "" + a.name shouldBe "" alloc.name shouldBe Operators.alloc alloc.methodFullName shouldBe Operators.alloc - alloc.code shouldBe "A.new" + alloc.code shouldBe "A.new 1, 2" + alloc.argument.size shouldBe 0 case xs => fail(s"Expected one identifier and one call argument, got [${xs.code.mkString(",")}]") } case xs => fail(s"Expected a single assignment, got [${xs.code.mkString(",")}]") @@ -191,15 +194,71 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { } "create a call to the object's constructor, with the temp variable receiver" in { - inside(cpg.call.nameExact("new").l) { + inside(cpg.call.nameExact(RubyDefines.Initialize).l) { case constructor :: Nil => inside(constructor.argument.l) { - case (a: Identifier) :: Nil => + case (a: Identifier) :: (one: Literal) :: (two: Literal) :: Nil => + a.name shouldBe "" + a.typeFullName shouldBe s"Test0.rb:$Main.A" + a.argumentIndex shouldBe 0 + + one.code shouldBe "1" + two.code shouldBe "2" + case xs => fail(s"Expected one identifier and one call argument, got [${xs.code.mkString(",")}]") + } + + val recv = constructor.receiver.head.asInstanceOf[Call] + recv.methodFullName shouldBe Operators.fieldAccess + recv.name shouldBe Operators.fieldAccess + recv.code shouldBe s"A.${RubyDefines.Initialize}" + + recv.argument(1).label shouldBe NodeTypes.CALL + recv.argument(1).code shouldBe "self.A" + recv.argument(2).label shouldBe NodeTypes.FIELD_IDENTIFIER + recv.argument(2).code shouldBe RubyDefines.Initialize + case xs => fail(s"Expected a single alloc, got [${xs.code.mkString(",")}]") + } + } + } + + "an object instantiation from some expression" should { + val cpg = code("""def foo + | params[:type].constantize.new(path) + |end + |""".stripMargin) + + "create a call node on the receiver end of the constructor lowering" in { + inside(cpg.call.nameExact(RubyDefines.Initialize).l) { + case constructor :: Nil => + inside(constructor.argument.l) { + case (a: Identifier) :: (selfPath: Call) :: Nil => a.name shouldBe "" - a.typeFullName shouldBe "Test0.rb:::program.A" + a.typeFullName shouldBe Defines.Any a.argumentIndex shouldBe 0 + + selfPath.code shouldBe "self.path" case xs => fail(s"Expected one identifier and one call argument, got [${xs.code.mkString(",")}]") } + + val recv = constructor.receiver.head.asInstanceOf[Call] + recv.methodFullName shouldBe Operators.fieldAccess + recv.name shouldBe Operators.fieldAccess + recv.code shouldBe s"( = params[:type].constantize).${RubyDefines.Initialize}" + + recv.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe RubyDefines.Initialize + + inside(recv.argument(1).start.isCall.argument(2).isCall.argument.l) { + case (paramsAssign: Call) :: (constantize: FieldIdentifier) :: Nil => + paramsAssign.code shouldBe " = params[:type]" + inside(paramsAssign.argument.l) { case (tmpIdent: Identifier) :: (indexAccess: Call) :: Nil => + tmpIdent.name shouldBe "" + + indexAccess.name shouldBe Operators.indexAccess + indexAccess.code shouldBe "params[:type]" + } + + constantize.canonicalName shouldBe "constantize" + } case xs => fail(s"Expected a single alloc, got [${xs.code.mkString(",")}]") } } @@ -218,7 +277,26 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { inside(cpg.call("src").l) { case src :: Nil => src.name shouldBe "src" - src.methodFullName shouldBe "Test0.rb:::program:src" + src.methodFullName shouldBe s"Test0.rb:$Main.src" + case xs => fail(s"Expected exactly one `src` call, instead got [${xs.code.mkString(",")}]") + } + } + } + + "a parenthesis-less call as the base of a member access" should { + val cpg = code(""" + |def f(p) + | src.join(",") + |end + | + |def src = [1, 2] + |""".stripMargin) + + "correctly create a `src` call instead of identifier" in { + inside(cpg.call("src").l) { + case src :: Nil => + src.name shouldBe "src" + src.methodFullName shouldBe s"Test0.rb:$Main.src" case xs => fail(s"Expected exactly one `src` call, instead got [${xs.code.mkString(",")}]") } } @@ -260,10 +338,26 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { inArg.argumentName shouldBe Option("in") } + "Calls with named arguments using symbols and hash rocket syntax" in { + val cpg = code("render :foo => \"bar\"") + val List(_, barArg: Literal) = cpg.call.nameExact("render").argument.l: @unchecked + barArg.code shouldBe "\"bar\"" + barArg.argumentName shouldBe Option("foo") + } + + "named parameters in parenthesis-less call with a known keyword as the association key should shadow the keyword" in { + val cpg = code(""" + |foo retry: 3 + |""".stripMargin) + val List(_, retry) = cpg.call.nameExact("foo").argument.l: @unchecked + retry.code shouldBe "3" + retry.argumentName shouldBe Some("retry") + } + "a call with a quoted regex literal should have a literal receiver" in { - val cpg = code("%r{^/}.freeze") + val cpg = code("%r{^/}.freeze()") val regexLiteral = cpg.call.nameExact("freeze").receiver.fieldAccess.argument(1).head.asInstanceOf[Literal] - regexLiteral.typeFullName shouldBe s"$kernelPrefix.Regexp" + regexLiteral.typeFullName shouldBe RubyDefines.prefixAsCoreType(RubyDefines.Regexp) regexLiteral.code shouldBe "%r{^/}" } @@ -271,14 +365,222 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { val cpg = code("::Augeas.open { |aug| aug.get('/augeas/version') }") val augeasReceiv = cpg.call.nameExact("open").receiver.head.asInstanceOf[Call] augeasReceiv.methodFullName shouldBe Operators.fieldAccess - augeasReceiv.code shouldBe "::Augeas.open" + augeasReceiv.code shouldBe "( = ::Augeas).open" val selfAugeas = augeasReceiv.argument(1).asInstanceOf[Call] - selfAugeas.argument(1).asInstanceOf[Identifier].name shouldBe RubyDefines.Self - selfAugeas.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "Augeas" + selfAugeas.argument(1).asInstanceOf[Identifier].name shouldBe "" + selfAugeas.argument(2).asInstanceOf[Call].code shouldBe "self::Augeas" augeasReceiv.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "open" } + "`nil` keyword as a member access should be a literal" in { + val cpg = code("nil.to_json") + val toJson = cpg.fieldAccess.codeExact("nil.to_json").head + val nilRec = toJson.argument(1).asInstanceOf[Literal] + + nilRec.code shouldBe "nil" + nilRec.lineNumber shouldBe Option(1) + } + + "Object initialize calls should be DynamicUnknown" in { + val cpg = code("""Date.new(2013, 19, 20)""") + + inside(cpg.call.name(RubyDefines.Initialize).l) { + case initCall :: Nil => + initCall.methodFullName shouldBe Defines.DynamicCallUnknownFullName + case xs => fail(s"Expected one call to initialize, got ${xs.code.mkString}") + } + } + + "Member calls where the LHS is a call" should { + + "assign the first call to a temp variable to avoid a second invocation at arg 0" in { + val cpg = code("a().b()") + + val bCall = cpg.call("b").head + bCall.code shouldBe "( = a()).b()" + + // Check receiver + val bAccess = bCall.receiver.isCall.head + bAccess.name shouldBe Operators.fieldAccess + bAccess.methodFullName shouldBe Operators.fieldAccess + bAccess.code shouldBe "( = a()).b" + + bAccess.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "b" + + val aAssign = bAccess.argument(1).asInstanceOf[Call] + aAssign.name shouldBe Operators.assignment + aAssign.methodFullName shouldBe Operators.assignment + aAssign.code shouldBe " = a()" + + aAssign.argument(1).asInstanceOf[Identifier].name shouldBe "" + aAssign.argument(2).asInstanceOf[Call].name shouldBe "a" + + // Check (cached) base + val base = bCall.argument(0).asInstanceOf[Identifier] + base.name shouldBe "" + } + } + + "Call with Array Argument" in { + val cpg = code(""" + |def foo(a) + | puts a + |end + | + |foo([:b, :c => 1]) + |""".stripMargin) + + inside(cpg.call.name("foo").l) { + case fooCall :: Nil => + inside(fooCall.argument.l) { + case _ :: (arrayArg: Block) :: Nil => + arrayArg.code shouldBe "[:b, :c => 1]" + + inside(arrayArg.astChildren.l) { + case (_: Call) :: (elem1: Call) :: (elem2: Call) :: (_: Identifier) :: Nil => + elem1.code shouldBe "[0] = :b" + elem2.code shouldBe "[1] = :c => 1" + + elem1.methodFullName shouldBe Operators.assignment + elem2.methodFullName shouldBe Operators.assignment + elem2.argument(2).asInstanceOf[Call].methodFullName shouldBe RubyOperators.association + case xs => fail(s"Expected two args for elements, got ${xs.code.mkString(",")}") + } + case xs => fail(s"Expected two args, got ${xs.map(x => x.label -> x.code).mkString(",")}") + } + case xs => fail(s"Expected one call for foo, got ${xs.code.mkString}") + } + } + + "Calls separated by `tmp` should render correct `code` properties" in { + val cpg = code(""" + |User.find_by(auth_token: cookies[:auth_token].to_s) + |""".stripMargin) + + cpg.call("find_by").code.head shouldBe "( = User).find_by(auth_token: cookies[:auth_token].to_s)" + cpg.call(Operators.indexAccess).code.head shouldBe "cookies[:auth_token]" + cpg.fieldAccess + .where(_.fieldIdentifier.canonicalNameExact("@to_s")) + .code + .head shouldBe "( = cookies[:auth_token]).to_s" + } + + "Calls with multiple splat args" in { + val cpg = code(""" + | doorkeeper_application&.includes_scope?( + | *::Gitlab::Auth::API_SCOPE, *::Gitlab::Auth::READ_API_SCOPE, + | *::Gitlab::Auth::ADMIN_SCOPES, *::Gitlab::Auth::REPOSITORY_SCOPES, + | *::Gitlab::Auth::REGISTRY_SCOPES + | ) + |""".stripMargin) + + inside(cpg.call.name("includes_scope\\?").argument.l) { + case _ :: (apiScopeSplat: Call) :: (readScopeSplat: Call) :: (adminScopeSplat: Call) :: (repoScopeSplat: Call) :: (registryScopeSplat: Call) :: Nil => + apiScopeSplat.code shouldBe "*::Gitlab::Auth::API_SCOPE" + apiScopeSplat.methodFullName shouldBe RubyOperators.splat + + readScopeSplat.code shouldBe "*::Gitlab::Auth::READ_API_SCOPE" + readScopeSplat.methodFullName shouldBe RubyOperators.splat + + adminScopeSplat.code shouldBe "*::Gitlab::Auth::ADMIN_SCOPES" + adminScopeSplat.methodFullName shouldBe RubyOperators.splat + + repoScopeSplat.code shouldBe "*::Gitlab::Auth::REPOSITORY_SCOPES" + repoScopeSplat.methodFullName shouldBe RubyOperators.splat + + registryScopeSplat.code shouldBe "*::Gitlab::Auth::REGISTRY_SCOPES" + registryScopeSplat.methodFullName shouldBe RubyOperators.splat + + case xs => fail(s"Expected 5 arguments for call, got [${xs.code.mkString(",")}]") + } + } + + "Multiple different arg types in a call" in { + val cpg = code(""" + |params.require(:issue).permit( + | "1234", + | 10, + | *issue_params_attributes, + | sentry_issue_attributes: [:sentry_issue_identifier] + | ) + |""".stripMargin) + + inside(cpg.call.name("permit").argument.l) { + case _ :: (strLiteral: Literal) :: (numericLiteral: Literal) :: (issueSplat: Call) :: (sentryAssoc: Block) :: Nil => + issueSplat.code shouldBe "*issue_params_attributes" + issueSplat.methodFullName shouldBe RubyOperators.splat + + sentryAssoc.code shouldBe "[:sentry_issue_identifier]" + + strLiteral.code shouldBe "\"1234\"" + strLiteral.typeFullName shouldBe RubyDefines.prefixAsCoreType(RubyDefines.String) + + numericLiteral.code shouldBe "10" + numericLiteral.typeFullName shouldBe RubyDefines.prefixAsCoreType(RubyDefines.Integer) + case xs => fail(s"Expected 6 parameters for call, got [${xs.code.mkString(", ")}]") + } + } + + "Call with association IndexAccess key" in { + val cpg = code(""" + |foo(bar[:baz] => nil) + |""".stripMargin) + + inside(cpg.call.name("foo").argument.l) { + case _ :: (assocParam: Call) :: Nil => + assocParam.methodFullName shouldBe RubyOperators.association + assocParam.code shouldBe "bar[:baz] => nil" + + inside(assocParam.argument.l) { + case (lhs: Call) :: (rhs: Literal) :: Nil => + lhs.methodFullName shouldBe Operators.indexAccess + lhs.code shouldBe "bar[:baz]" + + rhs.code shouldBe "nil" + case xs => fail(s"Expected lhs and rhs for association, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected two params, got [${xs.code.mkString(",")}]") + } + } + + "Call with association MemberAccess key" in { + val cpg = code(""" + |foo(bar.baz => nil) + |""".stripMargin) + + inside(cpg.call.name("foo").argument.l) { + case _ :: (assocParam: Call) :: Nil => + assocParam.methodFullName shouldBe RubyOperators.association + assocParam.code shouldBe "bar.baz => nil" + + inside(assocParam.argument.l) { + case (lhs: Call) :: (rhs: Literal) :: Nil => + lhs.methodFullName shouldBe Operators.fieldAccess + lhs.code shouldBe "bar.baz" + + rhs.code shouldBe "nil" + case xs => fail(s"Expected lhs and rhs for association, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected two params, got [${xs.code.mkString(",")}]") + } + } + + "A Set instantiation with a 'brackets' call" should { + val cpg = code("Set[]") + + "be a call with name '[]'" in { + inside(cpg.call.nameExact("[]").l) { case bracketCall :: Nil => + bracketCall.code shouldBe "( = Set).[]()" + bracketCall.name shouldBe "[]" + inside(bracketCall.argument.l) { case (tmpBase: Identifier) :: Nil => + tmpBase.name shouldBe "" + tmpBase.code shouldBe "" + tmpBase.argumentIndex shouldBe 0 + } + } + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala index 2d9ae6491306..e2c2483df42d 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala @@ -1,9 +1,10 @@ package io.joern.rubysrc2cpg.querying import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language.* -import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment class CaseTests extends RubyCode2CpgFixture { @@ -20,16 +21,18 @@ class CaseTests extends RubyCode2CpgFixture { |""".stripMargin val cpg = code(caseCode) - val block @ List(_) = cpg.method(":program").block.astChildren.isBlock.l + val block @ List(_) = cpg.method.isModule.block.astChildren.isBlock.l - val List(assign) = block.astChildren.assignment.l; + val List(assign) = block.astChildren.collect { + case x: Call if x.name == Operators.assignment => x.asInstanceOf[Assignment] + }.l val List(lhs, rhs) = assign.argument.l List(lhs).isIdentifier.name.l shouldBe List("") List(rhs).isLiteral.code.l shouldBe List("0") val headIf @ List(_) = block.astChildren.isControlStructure.l - val ifStmts @ List(_, _, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l; + val ifStmts @ List(_, _, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l val conds: List[List[String]] = ifStmts.condition.map { cond => val orConds = List(cond) .repeat(_.isCall.where(_.name(Operators.logicalOr)).argument)( @@ -54,7 +57,6 @@ class CaseTests extends RubyCode2CpgFixture { // It's not ideal, but we choose the smallest containing text span that we have easily acesssible // as we don't have a good way to immutably update RubyNode text spans. - ifStmts.code.l should contain only caseCode.trim ifStmts.condition.map(_.code.trim).l shouldBe List("0", "when 1,2 then 1", "when 3, *[4,5] then 2", "*[6]") } @@ -68,7 +70,7 @@ class CaseTests extends RubyCode2CpgFixture { |end |""".stripMargin) - val block @ List(_) = cpg.method(":program").block.astChildren.isBlock.l + val block @ List(_) = cpg.method.isModule.block.astChildren.isBlock.l val headIf @ List(_) = block.astChildren.isControlStructure.l val ifStmts @ List(_, _, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l; @@ -98,4 +100,124 @@ class CaseTests extends RubyCode2CpgFixture { ifStmts.last.astChildren.order(3).l shouldBe List() } + + "An array pattern match statement" in { + val cpg = code(""" + |def self.class_for(type, location) + | case [type, location] + | in [:value, :path] + | ItemValuePathParsingError + | in [:label, :path] + | ItemLabelPathParsingError + | in [:label, :invalid] + | ItemLabelInvalidTypeParsingError + | else + | SomeOtherError + | end + |end""".stripMargin) + + val block @ List(_) = cpg.method.name("class_for").block.astChildren.isBlock.l + + val assign = block.astChildren.assignment.head + val List(lhs, rhs) = assign.argument.l + + lhs.start.isIdentifier.name.l shouldBe List("") + rhs.start.isBlock.code.l shouldBe List("[type, location]") // array lowering + + val headIf @ List(_) = block.astChildren.isControlStructure.l + val ifStmts @ List(_, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l; + + val conds: List[List[String]] = ifStmts.condition.map { cond => + val orConds = List(cond) + .repeat(_.isCall.where(_.name(Operators.logicalOr)).argument)( + _.emit(_.whereNot(_.isCall.name(Operators.logicalOr))) + ) + .l + orConds.map { + case mExpr: Call if mExpr.name == "include?" => + val List(_, lhs, rhs) = mExpr.astChildren.l + rhs.code shouldBe "" + s"splat:${lhs.code}" + case mExpr: Call if mExpr.name == Operators.equals => + val List(lhs: Call, rhs) = mExpr.argument.l: @unchecked + rhs.code shouldBe "" + lhs.methodFullName shouldBe Operators.arrayInitializer + s"expr:${lhs.code}" + }.l + }.l + + conds shouldBe List(List("expr:[:value, :path]"), List("expr:[:label, :path]"), List("expr:[:label, :invalid]")) + } + + "An array pattern match statement with variables in the pattern" in { + val cpg = code(""" + |def self.class_for(type, location) + | case [type, location] + | in [:value, result] + | puts "#{result}" + | in [:label, notResult] + | puts "#{notResult}" + | else + | puts "else" + | end + |end""".stripMargin) + + val List(_, resultLocal, notResultLocal) = cpg.method.name("class_for").block.astChildren.isLocal.l + resultLocal.name shouldBe "result" + notResultLocal.name shouldBe "notResult" + + val block @ List(_) = cpg.method.name("class_for").block.astChildren.isBlock.l + + val assign = block.astChildren.assignment.head + val List(lhs, rhs) = assign.argument.l + + lhs.start.isIdentifier.name.l shouldBe List("") + rhs.start.isBlock.code.l shouldBe List("[type, location]") // where the array lowering happens + + val headIf @ List(_) = block.astChildren.isControlStructure.l + val ifStmts @ List(_, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l; + + val conds: List[List[String]] = ifStmts.condition.map { cond => + val orConds = List(cond) + .repeat(_.isCall.where(_.name(Operators.logicalOr)).argument)( + _.emit(_.whereNot(_.isCall.name(Operators.logicalOr))) + ) + .l + orConds.map { + case mExpr: Call if mExpr.name == "include?" => + val List(_, lhs, rhs) = mExpr.astChildren.l + rhs.code shouldBe "" + s"splat:${lhs.code}" + case mExpr: Call if mExpr.name == Operators.equals => + val List(lhs: Call, rhs) = mExpr.argument.l: @unchecked + rhs.code shouldBe "" + lhs.methodFullName shouldBe Operators.arrayInitializer + s"expr:${lhs.code}" + }.l + }.l + + conds shouldBe List(List("expr:[:value, result]"), List("expr:[:label, notResult]")) + + inside(ifStmts.whenTrue.isBlock.astChildren.isCall.name(Operators.assignment).l) { + case resultMatchAssignment :: notResultMatchAssignment :: Nil => + inside(resultMatchAssignment.argument.l) { + case (lhs: Identifier) :: (rhs: Call) :: Nil => + lhs.name shouldBe "result" + + rhs.methodFullName shouldBe Operators.indexAccess + rhs.code shouldBe s"[1]" + case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]") + } + + inside(notResultMatchAssignment.argument.l) { + case (lhs: Identifier) :: (rhs: Call) :: Nil => + lhs.name shouldBe "notResult" + + rhs.methodFullName shouldBe Operators.indexAccess + rhs.code shouldBe s"[1]" + case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]") + } + case _ => fail(s"Expected two true branches") + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala index 9c03d330146d..3ad684c1c940 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala @@ -1,10 +1,11 @@ package io.joern.rubysrc2cpg.querying +import io.joern.rubysrc2cpg.passes.Defines.{Initialize, Main, TypeDeclBody} import io.joern.rubysrc2cpg.passes.{GlobalTypes, Defines as RubyDefines} import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.joern.x2cpg.Defines -import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, NodeTypes, Operators} import io.shiftleft.semanticcpg.language.* class ClassTests extends RubyCode2CpgFixture { @@ -17,7 +18,7 @@ class ClassTests extends RubyCode2CpgFixture { val List(classC) = cpg.typeDecl.name("C").l classC.inheritsFromTypeFullName shouldBe List() - classC.fullName shouldBe "Test0.rb:::program.C" + classC.fullName shouldBe s"Test0.rb:$Main.C" classC.lineNumber shouldBe Some(2) classC.baseType.l shouldBe List() classC.member.name.l shouldBe List(RubyDefines.TypeDeclBody, RubyDefines.Initialize) @@ -25,7 +26,7 @@ class ClassTests extends RubyCode2CpgFixture { val List(singletonC) = cpg.typeDecl.nameExact("C").l singletonC.inheritsFromTypeFullName shouldBe List() - singletonC.fullName shouldBe "Test0.rb:::program.C" + singletonC.fullName shouldBe s"Test0.rb:$Main.C" singletonC.lineNumber shouldBe Some(2) singletonC.baseType.l shouldBe List() singletonC.member.name.l shouldBe List() @@ -42,7 +43,7 @@ class ClassTests extends RubyCode2CpgFixture { val List(classC) = cpg.typeDecl.name("C").l classC.inheritsFromTypeFullName shouldBe List("D") - classC.fullName shouldBe "Test0.rb:::program.C" + classC.fullName shouldBe s"Test0.rb:$Main.C" classC.lineNumber shouldBe Some(2) classC.member.name.l shouldBe List(RubyDefines.TypeDeclBody, RubyDefines.Initialize) classC.method.name.l shouldBe List(RubyDefines.TypeDeclBody, RubyDefines.Initialize) @@ -53,7 +54,7 @@ class ClassTests extends RubyCode2CpgFixture { val List(singletonC) = cpg.typeDecl.nameExact("C").l singletonC.inheritsFromTypeFullName shouldBe List("D") - singletonC.fullName shouldBe "Test0.rb:::program.C" + singletonC.fullName shouldBe s"Test0.rb:$Main.C" singletonC.lineNumber shouldBe Some(2) singletonC.member.name.l shouldBe List() singletonC.method.name.l shouldBe List() @@ -74,6 +75,9 @@ class ClassTests extends RubyCode2CpgFixture { val List(singletonC) = cpg.typeDecl.name("C").l singletonC.member.nameExact("@a").isEmpty shouldBe true + + val List(aGetterMember) = classC.member.nameExact("a").l + aGetterMember.dynamicTypeHintFullName should contain("Test0.rb:
.C.a") } "`attr_reader :'abc'` is represented by a `@abc` MEMBER node" in { @@ -88,6 +92,9 @@ class ClassTests extends RubyCode2CpgFixture { abcMember.code shouldBe "attr_reader :'abc'" abcMember.lineNumber shouldBe Some(3) + + val List(aMember) = classC.member.nameExact("abc").l + aMember.dynamicTypeHintFullName should contain("Test0.rb:
.C.abc") } "`attr_reader :'abc' creates an `abc` METHOD node" in { @@ -102,14 +109,17 @@ class ClassTests extends RubyCode2CpgFixture { methodAbc.code shouldBe "def abc (...)" methodAbc.lineNumber shouldBe Some(3) - methodAbc.parameter.isEmpty shouldBe true - methodAbc.fullName shouldBe "Test0.rb:::program.C:abc" + methodAbc.parameter.indexGt(0).isEmpty shouldBe true + methodAbc.fullName shouldBe s"Test0.rb:$Main.C.abc" - // TODO: Make sure that @abc in this return is the actual field val List(ret: Return) = methodAbc.methodReturn.cfgIn.l: @unchecked - val List(abcField: Identifier) = ret.astChildren.l: @unchecked - ret.code shouldBe "return @abc" - abcField.name shouldBe "@abc" + val List(abcFieldAccess: Call) = ret.astChildren.l: @unchecked + ret.code shouldBe "@abc" + abcFieldAccess.name shouldBe Operators.fieldAccess + abcFieldAccess.code shouldBe "self.@abc" + + val List(aMember) = classC.member.nameExact("abc").l + aMember.dynamicTypeHintFullName should contain("Test0.rb:
.C.abc") } "`attr_reader :a, :b` is represented by `@a`, `@b` MEMBER nodes" in { @@ -152,15 +162,19 @@ class ClassTests extends RubyCode2CpgFixture { methodA.code shouldBe "def a= (...)" methodA.lineNumber shouldBe Some(3) - methodA.fullName shouldBe "Test0.rb:::program.C:a=" + methodA.fullName shouldBe s"Test0.rb:$Main.C.a=" // TODO: there's probably a better way for testing this - val List(param) = methodA.parameter.l - val List(assignment) = methodA.assignment.l - val List(lhs: Identifier, rhs: Identifier) = assignment.argument.l: @unchecked + val List(_, param) = methodA.parameter.l + val List(assignment) = methodA.assignment.l + val List(lhs: Call, rhs: Identifier) = assignment.argument.l: @unchecked param.name shouldBe rhs.name - lhs.name shouldBe "@a" + lhs.name shouldBe Operators.fieldAccess + lhs.code shouldBe "self.@a" + + val List(aMember) = classC.member.nameExact("a=").l + aMember.dynamicTypeHintFullName should contain("Test0.rb:
.C.a=") } "`attr_accessor :a` is represented by a `@a` MEMBER node" in { @@ -177,6 +191,29 @@ class ClassTests extends RubyCode2CpgFixture { aMember.lineNumber shouldBe Some(3) } + "`attr_reader` in a nested class should generate the correct member for the correct class" in { + val cpg = code(""" + |class Foo + | + | class Bar + | attr_reader :a + | end + | + |end + | + |""".stripMargin) + + val List(bar) = cpg.typeDecl.name("Bar").l + val List(aMember) = bar.member.name("@a").l + + aMember.code shouldBe "attr_reader :a" + aMember.lineNumber shouldBe Some(5) + + // calls are pushed all the way up to the + cpg.typeDecl.astOut.isCall.size shouldBe 0 + cpg.call(RubyDefines.TypeDeclBody).method.dedup.name.l shouldBe List(RubyDefines.Main) + } + "`def f(x) ... end` is represented by a METHOD inside the TYPE_DECL node" in { val cpg = code(""" |class C @@ -189,7 +226,7 @@ class ClassTests extends RubyCode2CpgFixture { val List(classC) = cpg.typeDecl.name("C").l val List(methodF) = classC.method.name("f").l - methodF.fullName shouldBe "Test0.rb:::program.C:f" + methodF.fullName shouldBe s"Test0.rb:$Main.C.f" val List(memberF) = classC.member.nameExact("f").l memberF.dynamicTypeHintFullName.toSet should contain(methodF.fullName) @@ -261,7 +298,7 @@ class ClassTests extends RubyCode2CpgFixture { val List(classC) = cpg.typeDecl.name("C").l val List(methodInit) = classC.method.name(RubyDefines.Initialize).l - methodInit.fullName shouldBe s"Test0.rb:::program.C:${RubyDefines.Initialize}" + methodInit.fullName shouldBe s"Test0.rb:$Main.C.${RubyDefines.Initialize}" methodInit.isConstructor.isEmpty shouldBe false } @@ -274,7 +311,7 @@ class ClassTests extends RubyCode2CpgFixture { val List(classC) = cpg.typeDecl.name("C").l val List(methodInit) = classC.method.name(RubyDefines.Initialize).l - methodInit.fullName shouldBe s"Test0.rb:::program.C:${RubyDefines.Initialize}" + methodInit.fullName shouldBe s"Test0.rb:$Main.C.${RubyDefines.Initialize}" } "only `def initialize() ... end` directly under class has the constructor modifier" in { @@ -312,8 +349,8 @@ class ClassTests extends RubyCode2CpgFixture { | |""".stripMargin) - cpg.member("MConst").typeDecl.fullName.head shouldBe "Test0.rb:::program.MMM" - cpg.member("NConst").typeDecl.fullName.head shouldBe "Test0.rb:::program.MMM.Nested" + cpg.member("MConst").typeDecl.fullName.head shouldBe s"Test0.rb:$Main.MMM" + cpg.member("NConst").typeDecl.fullName.head shouldBe s"Test0.rb:$Main.MMM.Nested" } "a basic anonymous class" should { @@ -329,14 +366,14 @@ class ClassTests extends RubyCode2CpgFixture { inside(cpg.typeDecl.nameExact("").l) { case anonClass :: Nil => anonClass.name shouldBe "" - anonClass.fullName shouldBe "Test0.rb:::program." + anonClass.fullName shouldBe s"Test0.rb:$Main." inside(anonClass.method.l) { case hello :: defaultConstructor :: Nil => defaultConstructor.name shouldBe RubyDefines.Initialize - defaultConstructor.fullName shouldBe s"Test0.rb:::program.:${RubyDefines.Initialize}" + defaultConstructor.fullName shouldBe s"Test0.rb:$Main..${RubyDefines.Initialize}" hello.name shouldBe "hello" - hello.fullName shouldBe "Test0.rb:::program.:hello" + hello.fullName shouldBe s"Test0.rb:$Main..hello" case xs => fail(s"Expected a single method, but got [${xs.map(x => x.label -> x.code).mkString(",")}]") } case xs => fail(s"Expected a single anonymous class, but got [${xs.map(x => x.label -> x.code).mkString(",")}]") @@ -344,18 +381,20 @@ class ClassTests extends RubyCode2CpgFixture { } "generate an assignment to the variable `a` with the source being a constructor invocation of the class" in { - inside(cpg.method(":program").assignment.l) { - case aAssignment :: Nil => + inside(cpg.method.isModule.assignment.l) { + case aAssignment :: tmpAssign :: Nil => aAssignment.target.code shouldBe "a" - aAssignment.source.code shouldBe "Class.new (...)" + aAssignment.source.code shouldBe "( = Class.new (...)).new" + + tmpAssign.target.code shouldBe "" + tmpAssign.source.code shouldBe "self.Class.new (...)" case xs => fail(s"Expected a single assignment, but got [${xs.map(x => x.label -> x.code).mkString(",")}]") } } } - // TODO: This should be remodelled as a property access `animal.bark = METHOD_REF` - "a basic singleton class" ignore { + "a basic singleton class extending an object instance" should { val cpg = code("""class Animal; end |animal = Animal.new | @@ -363,35 +402,49 @@ class ClassTests extends RubyCode2CpgFixture { | def bark | 'Woof' | end + | + | def legs + | 4 + | end |end | |animal.bark # => 'Woof' |""".stripMargin) - "generate a type decl with the associated members" in { - inside(cpg.typeDecl.nameExact("").l) { - case anonClass :: Nil => - anonClass.name shouldBe "" - anonClass.fullName shouldBe "Test0.rb:::program." - // TODO: Attempt to resolve the below with the `scope` class once we're handling constructors - anonClass.inheritsFromTypeFullName shouldBe Seq("animal") - inside(anonClass.method.l) { - case defaultConstructor :: bark :: Nil => - defaultConstructor.name shouldBe Defines.ConstructorMethodName - defaultConstructor.fullName shouldBe s"Test0.rb:::program.:${Defines.ConstructorMethodName}" + "Create assignments to method refs for methods on singleton object" in { + inside(cpg.method.isModule.block.assignment.l) { + case _ :: _ :: _ :: barkAssignment :: legsAssignment :: Nil => + inside(barkAssignment.argument.l) { + case (lhs: Call) :: (rhs: TypeRef) :: Nil => + val List(identifier, fieldIdentifier) = lhs.argument.l: @unchecked + identifier.code shouldBe "animal" + fieldIdentifier.code shouldBe "bark" + + rhs.typeFullName shouldBe s"Test0.rb:$Main.class< fail(s"Expected two arguments for assignment, got [${xs.code.mkString(",")}]") + } - bark.name shouldBe "bark" - bark.fullName shouldBe "Test0.rb:::program.:bark" - case xs => fail(s"Expected a single method, but got [${xs.map(x => x.label -> x.code).mkString(",")}]") + inside(legsAssignment.argument.l) { + case (lhs: Call) :: (rhs: TypeRef) :: Nil => + val List(identifier, fieldIdentifier) = lhs.argument.l: @unchecked + identifier.code shouldBe "animal" + fieldIdentifier.code shouldBe "legs" + + rhs.typeFullName shouldBe s"Test0.rb:$Main.class< fail(s"Expected two arguments for assignment, got [${xs.code.mkString(",")}]") } - case xs => fail(s"Expected a single anonymous class, but got [${xs.map(x => x.label -> x.code).mkString(",")}]") + case xs => fail(s"Expected five assignments, got [${xs.code.mkString(",")}]") } } - "register that `animal` may possibly be an instantiation of the singleton type" in { - cpg.local("animal").possibleTypes.l should contain("Test0.rb:::program.") + "Create TYPE_DECL nodes for two singleton methods" in { + inside(cpg.typeDecl.name("(bark|legs)").l) { + case barkTypeDecl :: legsTypeDecl :: Nil => + barkTypeDecl.fullName shouldBe s"Test0.rb:$Main.class< fail(s"Expected two type_decls, got [${xs.code.mkString(",")}]") + } } - } "if: as function param" should { @@ -414,7 +467,7 @@ class ClassTests extends RubyCode2CpgFixture { val List(validateCall: Call) = methodBlock.astChildren.isCall.l: @unchecked inside(validateCall.argument.l) { - case (identArg: Identifier) :: (passwordArg: Literal) :: (presenceArg: Literal) :: (confirmationArg: Literal) :: (lengthArg: Block) :: (onArg: Literal) :: (ifArg: Literal) :: Nil => + case (identArg: Identifier) :: (passwordArg: Literal) :: (presenceArg: Literal) :: (confirmationArg: Literal) :: (_: Block) :: (onArg: Literal) :: (ifArg: Literal) :: Nil => passwordArg.code shouldBe ":password" presenceArg.code shouldBe "true" confirmationArg.code shouldBe "true" @@ -434,8 +487,8 @@ class ClassTests extends RubyCode2CpgFixture { val cpg = code(""" | class AdminController < ApplicationController | before_action :administrative, if: :admin_param, except: [:get_user] - | skip_before_action :has_info - | layout false, only: [:get_all_users, :get_user] + | skip_before_action :has_info + | layout false, only: [:get_all_users, :get_user] | end |""".stripMargin) @@ -466,12 +519,14 @@ class ClassTests extends RubyCode2CpgFixture { } } - "fully qualified base types" should { + "base types names extending a class in the definition" should { val cpg = code("""require "rails/all" | |module Bar - | class Baz + | module Baz + | class Boz + | end | end |end | @@ -479,12 +534,12 @@ class ClassTests extends RubyCode2CpgFixture { | class Application < Rails::Application | end | - | class Foo < Bar::Baz + | class Foo < Bar::Baz::Boz | end |end |""".stripMargin) - "not confuse the internal `Application` with `Rails::Application` and leave the type unresolved" in { + "handle a qualified base type from an external type correctly" in { inside(cpg.typeDecl("Application").headOption) { case Some(app) => app.inheritsFromTypeFullName.head shouldBe "Rails.Application" @@ -492,10 +547,10 @@ class ClassTests extends RubyCode2CpgFixture { } } - "resolve the internal type being referenced" in { + "handle a deeply qualified internal base type correctly" in { inside(cpg.typeDecl("Foo").headOption) { case Some(app) => - app.inheritsFromTypeFullName.head shouldBe "Test0.rb:::program.Bar.Baz" + app.inheritsFromTypeFullName.head shouldBe "Bar.Baz.Boz" case None => fail("Expected a type decl for 'Foo', instead got nothing") } } @@ -576,6 +631,18 @@ class ClassTests extends RubyCode2CpgFixture { case xs => fail(s"Expected TypeDecl for Foo, instead got ${xs.name.mkString(", ")}") } } + + "call the body method" in { + inside(cpg.call.nameExact(RubyDefines.TypeDeclBody).headOption) { + case Some(bodyCall) => + bodyCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + bodyCall.methodFullName shouldBe s"Test0.rb:$Main.Foo.${RubyDefines.TypeDeclBody}" + bodyCall.code shouldBe "( = self::Foo)::()" + bodyCall.receiver.isEmpty shouldBe true + bodyCall.argument(0).code shouldBe "" + case None => fail("Expected call") + } + } } "Class Variables in Class and Methods" should { @@ -611,7 +678,7 @@ class ClassTests extends RubyCode2CpgFixture { cMember.code shouldBe "@@c" dMember.code shouldBe "@@d" oMember.code shouldBe "@@o" - case _ => fail("Expected 5 members") + case xs => fail(s"Expected 5 members, instead got ${xs.size}: [${xs.code.mkString(",")}]") } case xs => fail(s"Expected TypeDecl for Foo, instead got ${xs.name.mkString(", ")}") } @@ -625,7 +692,6 @@ class ClassTests extends RubyCode2CpgFixture { inside(clinitMethod.block.astChildren.isCall.name(Operators.assignment).l) { case aAssignment :: bAssignment :: cAssignment :: dAssignment :: oAssignment :: Nil => aAssignment.code shouldBe "@@a = nil" - bAssignment.code shouldBe "@@b = nil" cAssignment.code shouldBe "@@c = nil" dAssignment.code shouldBe "@@d = nil" @@ -646,9 +712,10 @@ class ClassTests extends RubyCode2CpgFixture { rhs.code shouldBe "nil" case _ => fail("Expected only LHS and RHS for assignment call") } - case _ => fail("") + case xs => + fail(s"Expected 5 fields initializers, got ${xs.size} instead ${xs.code.mkString(", ")}") } - case xs => fail(s"Expected one method for clinit, instead got ${xs.name.mkString(", ")}") + case xs => fail(s"Expected one method for , instead got ${xs.name.mkString(", ")}") } case xs => fail(s"Expected TypeDecl for Foo, instead got ${xs.name.mkString(", ")}") } @@ -680,7 +747,7 @@ class ClassTests extends RubyCode2CpgFixture { "create the `StandardError` local variable" in { cpg.local.nameExact("some_variable").dynamicTypeHintFullName.toList shouldBe List( - s"<${GlobalTypes.builtinPrefix}.StandardError>" + RubyDefines.prefixAsCoreType("StandardError") ) } @@ -738,18 +805,25 @@ class ClassTests extends RubyCode2CpgFixture { case fooClass :: Nil => inside(fooClass.method.name(RubyDefines.TypeDeclBody).l) { case initMethod :: Nil => - initMethod.code shouldBe "def \nscope :hits_by_ip, ->(ip, col = \"*\") { select(\"#{col}\").where(ip_address: ip).order(\"id DESC\") }\nend" + initMethod.code shouldBe "def ; (...); end" inside(initMethod.astChildren.isBlock.l) { case methodBlock :: Nil => inside(methodBlock.astChildren.l) { case methodCall :: Nil => inside(methodCall.astChildren.l) { - case (base: Call) :: (self: Identifier) :: (literal: Literal) :: (methodRef: MethodRef) :: Nil => + case (base: Call) :: (self: Identifier) :: (literal: Literal) :: (typeRef: TypeRef) :: Nil => base.code shouldBe "self.scope" self.name shouldBe "self" literal.code shouldBe ":hits_by_ip" - methodRef.methodFullName shouldBe s"Test0.rb:::program.Foo:${RubyDefines.TypeDeclBody}:0" - methodRef.referencedMethod.parameter.indexGt(0).name.l shouldBe List("ip", "col") + typeRef.typeFullName shouldBe s"Test0.rb:$Main.Foo.${RubyDefines.TypeDeclBody}.0&Proc" + cpg.method + .fullNameExact( + typeRef.typ.referencedTypeDecl.member.name("call").dynamicTypeHintFullName.toSeq* + ) + .parameter + .indexGt(0) + .name + .l shouldBe List("ip", "col") case xs => fail(s"Expected three children, got ${xs.code.mkString(", ")} instead") } case xs => fail(s"Expected one call, got ${xs.code.mkString(", ")} instead") @@ -776,11 +850,488 @@ class ClassTests extends RubyCode2CpgFixture { case assignCall :: Nil => inside(assignCall.argument.l) { case lhs :: (rhs: Call) :: Nil => - rhs.typeFullName shouldBe "<__builtin.Encoding.Converter>:asciicompat_encoding" + rhs.typeFullName shouldBe "__builtin.Encoding.Converter.asciicompat_encoding" case xs => fail(s"Expected lhs and rhs for assignment call, got [${xs.code.mkString(",")}]") } case xs => fail(s"Expected one call for assignment, got [${xs.code.mkString(",")}]") } } } + + "Class definition on one line" should { + val cpg = code(""" + |class X 1 end + |""".stripMargin) + + "create TYPE_DECL" in { + inside(cpg.typeDecl.name("X").l) { + case xClass :: Nil => + inside(xClass.astChildren.isMethod.l) { + case bodyMethod :: initMethod :: Nil => + inside(bodyMethod.block.astChildren.l) { + case (literal: Literal) :: Nil => + literal.code shouldBe "1" + case xs => fail(s"Expected literal for body method, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected body and init method, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected one class, got [${xs.code.mkString(",")}]") + } + } + } + + "A call to super" should { + val cpg = code(""" + |class A + | def foo(a) + | end + |end + |class B < A + | def foo(a) + | super(a) + | end + |end + |""".stripMargin) + + "create a simple call" in { + val superCall = cpg.call.nameExact("super").head + superCall.code shouldBe "super(a)" + superCall.name shouldBe "super" + superCall.methodFullName shouldBe Defines.DynamicCallUnknownFullName + } + } + + "a class that is redefined should have a counter suffixed to ensure uniqueness" in { + val cpg = code(""" + |class Foo + | def foo;end + |end + |class Bar;end + |class Foo + | def foo;end + |end + |class Foo;end + |""".stripMargin) + + cpg.typeDecl.name("(Foo|Bar).*").filterNot(_.name.endsWith("")).name.l shouldBe List( + "Foo", + "Bar", + "Foo", + "Foo" + ) + cpg.typeDecl.name("(Foo|Bar).*").filterNot(_.name.endsWith("")).fullName.l shouldBe List( + s"Test0.rb:$Main.Foo", + s"Test0.rb:$Main.Bar", + s"Test0.rb:$Main.Foo0", + s"Test0.rb:$Main.Foo1" + ) + + cpg.method.nameExact("foo").fullName.l shouldBe List(s"Test0.rb:$Main.Foo.foo", s"Test0.rb:$Main.Foo0.foo") + + } + + "Class with nonAllowedTypeDeclChildren and explicit init" should { + val cpg = code(""" + |class Foo + | 1 + | def initialize(bar) + | puts bar + | end + |end + |""".stripMargin) + + "have an explicit init method" in { + inside(cpg.typeDecl.nameExact("Foo").method.l) { + case bodyMethod :: initMethod :: Nil => + bodyMethod.name shouldBe TypeDeclBody + + initMethod.name shouldBe Initialize + inside(initMethod.parameter.l) { + case selfParam :: barParam :: Nil => + selfParam.name shouldBe "self" + barParam.name shouldBe "bar" + case xs => fail(s"Expected two params, got [${xs.code.mkString(",")}]") + } + + inside(initMethod.block.astChildren.l) { + case (putsCall: Call) :: Nil => + putsCall.name shouldBe "puts" + case xs => fail(s"Expected one call, got [${xs.code.mkString(",")}]") + } + + inside(bodyMethod.block.astChildren.l) { + case (one: Literal) :: Nil => + one.code shouldBe "1" + one.typeFullName shouldBe RubyDefines.prefixAsCoreType("Integer") + case xs => fail(s"Expected one literal, got [${xs.code.mkString(",")}]") + } + + case xs => fail(s"Expected body method and init method, got [${xs.code.mkString(",")}]") + } + } + } + + "Class defined in Namespace" in { + val cpg = code(""" + |class Api::V1::MobileController + |end + |""".stripMargin) + + inside(cpg.namespaceBlock.fullNameExact("Api.V1").typeDecl.l) { + case mobileNamespace :: mobileClassNamespace :: Nil => + mobileNamespace.name shouldBe "MobileController" + mobileNamespace.fullName shouldBe "Test0.rb:
.Api.V1.MobileController" + + mobileClassNamespace.name shouldBe "MobileController" + mobileClassNamespace.fullName shouldBe "Test0.rb:
.Api.V1.MobileController" + case xs => fail(s"Expected two namespace blocks, got ${xs.code.mkString(",")}") + } + + inside(cpg.typeDecl.name("MobileController").l) { + case mobileTypeDecl :: Nil => + mobileTypeDecl.name shouldBe "MobileController" + mobileTypeDecl.fullName shouldBe "Test0.rb:
.Api.V1.MobileController" + mobileTypeDecl.astParentFullName shouldBe "Api.V1" + mobileTypeDecl.astParentType shouldBe NodeTypes.NAMESPACE_BLOCK + + mobileTypeDecl.astParent.isNamespaceBlock shouldBe true + + val namespaceDecl = mobileTypeDecl.astParent.asInstanceOf[NamespaceBlock] + namespaceDecl.name shouldBe "Api.V1" + namespaceDecl.filename shouldBe "Test0.rb" + + namespaceDecl.astParent.isFile shouldBe true + val parentFileDecl = namespaceDecl.astParent.asInstanceOf[File] + parentFileDecl.name shouldBe "Test0.rb" + + case xs => fail(s"Expected one class decl, got [${xs.code.mkString(",")}]") + } + } + + "Namespace scope is popping properly" in { + val cpg = code(""" + |class Foo::Bar + |end + | + |class Baz + |end + |""".stripMargin) + + inside(cpg.typeDecl.name("Baz").l) { + case bazTypeDecl :: Nil => + bazTypeDecl.fullName shouldBe "Test0.rb:
.Baz" + case xs => fail(s"Expected one type decl, got [${xs.code.mkString(",")}]") + } + } + + "Self param in static method" in { + val cpg = code(""" + |class Benefits < ApplicationRecord + |def self.save(file, backup = false) + | data_path = Rails.root.join("public", "data") + | full_file_name = "#{data_path}/#{file.original_filename}" + | f = File.open(full_file_name, "wb+") + | f.write file.read + | f.close + | make_backup(file, data_path, full_file_name) if backup == "true" + |end + |end + |""".stripMargin) + } + + "Splat Field Declaration" in { + val cpg = code(""" + | class EpisodeRssItem + | FOUND = %i[title itunes_subtitle].freeze + | attr_reader(*FOUND) + | attr_reader(*NOT_FOUND) + | end + | + |""".stripMargin) + + val List(titleMethod) = cpg.method.name("title").l + val List(itunesMethod) = cpg.method.name("itunes_subtitle").l + val List(bodyMethod) = cpg.method.name("").l + + inside(titleMethod.methodReturn.toReturn.l) { + case methodReturn :: Nil => + methodReturn.code shouldBe "@title" + case xs => fail(s"Expected one return, got [${xs.code.mkString(",")}]") + } + + inside(itunesMethod.methodReturn.toReturn.l) { + case methodReturn :: Nil => + methodReturn.code shouldBe "@itunes_subtitle" + case xs => fail(s"Expected one return, got [${xs.code.mkString(",")}]") + } + + inside(bodyMethod.call.name("attr_reader").l) { + case notFoundCall :: Nil => + notFoundCall.code shouldBe "attr_reader(*NOT_FOUND)" + inside(notFoundCall.argument.l) { + case _ :: splatArg :: Nil => + splatArg.code shouldBe "*NOT_FOUND" + case xs => fail(s"Expected two args, got ${xs.size}: [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected one call, got [${xs.code.mkString(",")}]") + } + } + + "Unknown Splat Field Declaration" in { + val cpg = code(""" + | class EpisodeRssItem + | attr_reader(*NOT_FOUND) + | end + |""".stripMargin) + + val List(bodyMethod) = cpg.method.name("").l + + inside(bodyMethod.call.name("attr_reader").l) { + case notFoundCall :: Nil => + notFoundCall.code shouldBe "attr_reader(*NOT_FOUND)" + inside(notFoundCall.argument.l) { + case _ :: splatArg :: Nil => + splatArg.code shouldBe "*NOT_FOUND" + case xs => fail(s"Expected two args, got ${xs.size}: [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected one call, got [${xs.code.mkString(",")}]") + } + } + + "FieldsDeclaration in `included` block" should { + val cpg = code(""" + |class Foo + | included do + | before_update :validate_workflows + | attr_accessor :bar + | end + |end + |""".stripMargin) + + "Create required getters and setters directly under TYPE_DECL" in { + inside(cpg.typeDecl.name("Foo").astChildren.isMethod.name("bar=?").l) { + case barGetter :: barSetter :: Nil => + barGetter.name shouldBe "bar" + barGetter.fullName shouldBe "Test0.rb:
.Foo.bar" + + barSetter.name shouldBe "bar=" + barSetter.fullName shouldBe "Test0.rb:
.Foo.bar=" + case xs => fail(s"Expected two method calls for getter and setter, got [${xs.code.mkString(",")}]") + } + } + + "Create required TYPE_DECL nodes directly under class TYPE_DECL" in { + inside(cpg.typeDecl.name("Foo").astChildren.isTypeDecl.name("bar=?").l) { + case barGetter :: barSetter :: Nil => + barGetter.name shouldBe "bar" + barSetter.name shouldBe "bar=" + case xs => fail(s"Expected two type decls, got [${xs.code.mkString(",")}]") + } + } + + "Create required MEMBER nodes directly under class TYPE_DECL" in { + inside(cpg.typeDecl.name("Foo").astChildren.isMember.name("bar=?").l) { + case barGetter :: barSetter :: Nil => + barGetter.name shouldBe "bar" + barSetter.name shouldBe "bar=" + case xs => fail(s"Expected two member nodes, got [${xs.code.mkString(",")}]") + } + } + } + + "Multiple FieldsDeclaration in included" should { + val cpg = code(""" + |class Foo + | included do + | attr_accessor :bar + | attr_reader :baz + | end + |end + |""".stripMargin) + + "Create required getters and setters directly under TYPE_DECL" in { + inside(cpg.typeDecl.name("Foo").astChildren.isMethod.name("(bar=?|baz)").l) { + case barGetter :: barSetter :: bazGetter :: Nil => + barGetter.name shouldBe "bar" + barGetter.fullName shouldBe "Test0.rb:
.Foo.bar" + + barSetter.name shouldBe "bar=" + barSetter.fullName shouldBe "Test0.rb:
.Foo.bar=" + + bazGetter.name shouldBe "baz" + bazGetter.fullName shouldBe "Test0.rb:
.Foo.baz" + case xs => fail(s"Expected three method defs for getter and setter, got [${xs.code.mkString(",")}]") + } + } + + "Create required TYPE_DECL nodes directly under class TYPE_DECL" in { + inside(cpg.typeDecl.name("Foo").astChildren.isTypeDecl.name("(bar=?|baz)").l) { + case barGetter :: barSetter :: bazGetter :: Nil => + barGetter.name shouldBe "bar" + barSetter.name shouldBe "bar=" + bazGetter.name shouldBe "baz" + case xs => fail(s"Expected two type decls, got [${xs.code.mkString(",")}]") + } + } + + "Create required MEMBER nodes directly under class TYPE_DECL" in { + inside(cpg.typeDecl.name("Foo").astChildren.isMember.name("(bar=?|baz)").l) { + case barGetter :: barSetter :: bazGetter :: Nil => + barGetter.name shouldBe "bar" + barSetter.name shouldBe "bar=" + bazGetter.name shouldBe "baz" + case xs => fail(s"Expected two member nodes, got [${xs.code.mkString(",")}]") + } + } + } + + "If Statement in class declaration with alias and method def" should { + val cpg = code(""" + |module Pessimistic + | if !method_defined?(:orig_lock!) + | alias orig_lock! lock! + | + | def lock!(lock = true) # rubocop:disable Style/OptionalBooleanParameter + | orig_lock!(lock) + | end + | + | end + |end + | + |""".stripMargin) + + "Lower Alias to directly under TYPE_DECL" in { + inside(cpg.typeDecl.name("Pessimistic").astChildren.isMethod.name("lock!").l) { + case lockAliasMethodDef :: _ :: Nil => + lockAliasMethodDef.name shouldBe "lock!" + + val List(_, args, blockArg) = lockAliasMethodDef.parameter.l + args.code shouldBe "*args" + blockArg.code shouldBe "&block" + + inside(lockAliasMethodDef.body.astChildren.isReturn.astChildren.isCall.l) { + case origLockCall :: Nil => + origLockCall.name shouldBe "orig_lock!" + origLockCall.code shouldBe "orig_lock!(*args, &block)" + case xs => fail(s"Expected one call, got ${xs.size}: [${xs.code.mkString(",")}]") + } + + case xs => fail(s"Expected two method defs, got ${xs.size}: [${xs.code.mkString(",")}]") + } + } + + "Lower method def to directly under TYPE_DECL" in { + inside(cpg.typeDecl.name("Pessimistic").astChildren.isMethod.name("lock!").l) { + case _ :: lockMethodDef :: Nil => + lockMethodDef.name shouldBe "lock!" + + val List(_, lockArg) = lockMethodDef.parameter.l + lockArg.code shouldBe "lock = true" + + inside(lockMethodDef.body.astChildren.isReturn.astChildren.isCall.l) { + case origLockCall :: Nil => + origLockCall.name shouldBe "orig_lock!" + origLockCall.code shouldBe "orig_lock!(lock)" + case xs => fail(s"Expected one call, got ${xs.size}: [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected two method defs, got ${xs.size}: [${xs.code.mkString(",")}]") + } + } + } + + "Implicit return for call to `private_class_method`" in { + val cpg = code(""" + |class Foo + | def case_sensitive_find_by() + | end + | + | included do + | private_class_method :case_sensitive_find_by + | end + |end + |""".stripMargin) + + inside(cpg.typeDecl.name("Foo").astChildren.isMethod.l) { + case lambdaMethod :: _ :: _ :: _ :: Nil => + val List(lambdaReturn) = lambdaMethod.body.astChildren.isReturn.l + + lambdaReturn.code shouldBe "private_class_method :case_sensitive_find_by" + + val List(returnCall) = lambdaReturn.astChildren.isCall.l + returnCall.code shouldBe "private_class_method :case_sensitive_find_by" + + val List(_, methodNameArg) = returnCall.argument.l + methodNameArg.code shouldBe "self.:case_sensitive_find_by" + + case xs => fail(s"Expected 5 methods, got [${xs.code.mkString(",")}]") + } + } + + "Implicit return of SingletonClassDeclaration" in { + val cpg = code(""" + |module Taskbar::List + | + | included do + | class << self + | def trigger_list_update(user, app) + | end + | end + | end + |end + |""".stripMargin) + inside(cpg.typeDecl.name("List").l) { + case listTypeDecl :: Nil => + val List(lambdaMethod) = listTypeDecl.astChildren.isMethod.isLambda.l + + val List(lambdaReturn) = lambdaMethod.astChildren.isBlock.astChildren.isReturn.l + lambdaReturn.code shouldBe "return nil" + + val List(lambdaTypeDecl, lambdaTypeDeclClass) = lambdaMethod.astChildren.isTypeDecl.l + lambdaTypeDecl.name shouldBe "" + lambdaTypeDeclClass.name shouldBe "" + + case xs => fail(s"expected 1 type, got ${xs.size}: [${xs.code.mkString(", ")}]") + } + } + + "`private_class_method` with unhandled argument types should not render these under the type decl" in { + val cpg = code(""" + |class Foo + | + | private_class_method %i[ + | filter_ignore_usable + | filter_key + | filter_usage + | parts + | ] + | + |end + | + |""".stripMargin) + + cpg.typeDecl("Foo").astChildren.whereNot(_.or(_.isMethod, _.isModifier, _.isTypeDecl, _.isMember)).size shouldBe 0 + } + + "Proc-param in method with instance field assignment and instance field argument" in { + val cpg = code(""" + |class Batches + | def as_batches(query, &) + | records.each(&) + | @limit -= 100 + | return if @limit.zero? + | end + |end + |""".stripMargin) + + inside(cpg.method.name("as_batches").l) { + case batchesMethod :: Nil => + inside(batchesMethod.parameter.l) { + case _ :: _ :: procParam :: Nil => + procParam.code shouldBe "&" + procParam.name shouldBe "" + case xs => fail(s"Expected three parameters, got (${xs.size}) [${xs.code.mkString(",")}]") + } + case xs => + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ConditionalTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ConditionalTests.scala index 3ed14c936954..96fbd50dcb1d 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ConditionalTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ConditionalTests.scala @@ -1,10 +1,9 @@ package io.joern.rubysrc2cpg.querying import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Local} +import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Identifier, Local} +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Operators} import io.shiftleft.semanticcpg.language.* -import io.shiftleft.codepropertygraph.generated.nodes.Call class ConditionalTests extends RubyCode2CpgFixture { @@ -47,10 +46,9 @@ class ConditionalTests extends RubyCode2CpgFixture { inside(cpg.call(Operators.conditional).l) { case cond :: Nil => inside(cond.argument.l) { - case x :: y :: z :: Nil => { + case x :: y :: z :: Nil => x.code shouldBe "x" List(y, z).isBlock.astChildren.isIdentifier.code.l shouldBe List("y", "z") - } case xs => fail(s"Expected exactly three arguments to conditional, got [${xs.code.mkString(",")}]") } case xs => fail(s"Expected exactly one conditional, got [${xs.code.mkString(",")}]") @@ -61,16 +59,22 @@ class ConditionalTests extends RubyCode2CpgFixture { val cpg = code("""x, y, z = false, true, false |f(unless x then y else z end) |""".stripMargin) - inside(cpg.call(Operators.conditional).l) { - case cond :: Nil => - inside(cond.argument.l) { - case x :: y :: z :: Nil => { - List(x).isCall.name(Operators.logicalNot).argument.code.l shouldBe List("x") - List(y, z).isBlock.astChildren.isIdentifier.code.l shouldBe List("y", "z") - } - case xs => fail(s"Expected exactly three arguments to conditional, got [${xs.code.mkString(",")}]") + inside(cpg.call.nameExact(Operators.conditional).l) { + case conditionalCall :: Nil => + conditionalCall.code shouldBe "unless x then y else z end" + inside(conditionalCall.argument.l) { + case condition :: (trueBranch: Block) :: (falseBranch: Block) :: Nil => + condition.code shouldBe "x" + + val List(trueBranchIdent) = trueBranch.astChildren.isIdentifier.l + trueBranchIdent.code shouldBe "z" + + val List(falseBranchIdent) = falseBranch.astChildren.isIdentifier.l + falseBranchIdent.code shouldBe "y" + + case xs => fail(s"Expected three arguments for conditional call, got [${xs.code.mkString(",")}]") } - case xs => fail(s"Expected exactly one conditional, got [${xs.code.mkString(",")}]") + case xs => fail(s"Expected one call to conditional, got [${xs.code.mkString(",")}]") } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ContentTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ContentTests.scala new file mode 100644 index 000000000000..85a8db6c92a9 --- /dev/null +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ContentTests.scala @@ -0,0 +1,77 @@ +package io.joern.rubysrc2cpg.querying + +import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.shiftleft.semanticcpg.language.* + +class ContentTests extends RubyCode2CpgFixture(disableFileContent = false) { + "Content of file" in { + val fileContent = + """ + |class Animal + |end + | + |def foo + | puts "a" + |end + |""".stripMargin + + val cpg = code(fileContent, "Test0.rb") + + cpg.file.name("Test0.rb").content.head shouldBe fileContent + } + + "Content of method" in { + + val fooFunc = + """def foo + | puts "a" + |end""".stripMargin + + val cpg = code(s"""$fooFunc""".stripMargin) + + val method = cpg.method.name("foo").head + + method.content.head shouldBe fooFunc + } + + "Content of Class" in { + val cls = + """class Animal + |end""".stripMargin + + val cpg = code(s"""$cls""".stripMargin) + val animal = cpg.typeDecl.name("Animal").head + + animal.content.head shouldBe cls + } + + "Content of Module" in { + val mod = """module Foo + |end""".stripMargin + + val cpg = code(mod) + val module = cpg.typeDecl.name("Foo").head + + module.content.head shouldBe mod + } + + "Method and Class content" in { + val cls = + """class Animal + |end""".stripMargin + + val fooFunc = + """def foo + | puts "a" + |end""".stripMargin + + val cpg = code(s"""$cls + |$fooFunc""".stripMargin) + + val method = cpg.method.name("foo").head + val animal = cpg.typeDecl.name("Animal").head + + method.content.head shouldBe fooFunc + animal.content.head shouldBe cls + } +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala index a24197513520..5c043380ba93 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala @@ -1,20 +1,21 @@ package io.joern.rubysrc2cpg.querying +import io.joern.rubysrc2cpg.passes.Defines import io.joern.rubysrc2cpg.passes.GlobalTypes.kernelPrefix import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Operators} -import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Identifier, Literal} import io.shiftleft.semanticcpg.language.* class ControlStructureTests extends RubyCode2CpgFixture { "`while-end` statement is represented by a `WHILE` CONTROL_STRUCTURE node" in { val cpg = code(""" - |x = 1 - |while x > 0 do - | x = x - 1 - |end - |""".stripMargin) + |x = 1 + |while x > 0 do + | x = x - 1 + |end + |""".stripMargin) val List(whileNode) = cpg.whileBlock.l val List(whileCond) = whileNode.condition.isCall.l @@ -31,13 +32,13 @@ class ControlStructureTests extends RubyCode2CpgFixture { "begin-end-until should be lowered as a do-while loop" in { val cpg = code(""" - |i = 0 - |num = 5 - |begin - | num = i + 3 - |end until i < num - |puts num - |""".stripMargin) + |i = 0 + |num = 5 + |begin + | num = i + 3 + |end until i < num + |puts num + |""".stripMargin) val List(whileNode) = cpg.doBlock.l val List(whileCond) = whileNode.condition.isCall.l @@ -54,11 +55,11 @@ class ControlStructureTests extends RubyCode2CpgFixture { "`until-end` statement is represented by a negated `WHILE` CONTROL_STRUCTURE node" in { val cpg = code(""" - |x = 1 - |until x <= 0 do - | x = x - 1 - |end - |""".stripMargin) + |x = 1 + |until x <= 0 do + | x = x - 1 + |end + |""".stripMargin) val List(untilNode) = cpg.whileBlock.l val List(untilNegCond) = untilNode.condition.isCall.l @@ -78,29 +79,32 @@ class ControlStructureTests extends RubyCode2CpgFixture { "a break expression nested in a control structure should be represented" in { val cpg = code(""" - |x = 0 - |num = -1 - |loop do - | num = x + 1 - | x = x + 1 - | if x > 10 - | break - | end - |end - |puts num - |""".stripMargin) + |x = 0 + |num = -1 + |loop do + | num = x + 1 + | x = x + 1 + | if x > 10 + | break + | end + |end + |puts num + |""".stripMargin) val List(breakNode) = cpg.break.l breakNode.code shouldBe "break" breakNode.lineNumber shouldBe Some(8) + // todo: investigate + // `loop` is lowered as a do-while loop with a true condition + cpg.controlStructure.condition("true").size shouldBe 1 } "`if-end` statement is represented by an `IF` CONTROL_STRUCTURE node" in { val cpg = code(""" - |if __LINE__ > 1 then - | "> 1" - |end - |""".stripMargin) + |if __LINE__ > 1 then + | "> 1" + |end + |""".stripMargin) val List(ifNode) = cpg.ifBlock.l val List(ifCond) = ifNode.condition.isCall.l @@ -116,12 +120,12 @@ class ControlStructureTests extends RubyCode2CpgFixture { "`if-else-end` statement is represented by `IF`-`ELSE` CONTROL_STRUCTURE nodes" in { val cpg = code(""" - |if __LINE__ > 1 then - | "> 1" - |else - | "<= 1" - |end - |""".stripMargin) + |if __LINE__ > 1 then + | "> 1" + |else + | "<= 1" + |end + |""".stripMargin) val List(ifNode) = cpg.ifBlock.l val List(ifCond) = ifNode.condition.isCall.l @@ -141,12 +145,12 @@ class ControlStructureTests extends RubyCode2CpgFixture { "`if-elsif-end` statement is represented by `IF`-`ELSE`-`IF` CONTROL_STRUCTURE nodes" in { val cpg = code(""" - |if __LINE__ == 0 then - | '= 0' - |elsif __LINE__ > 0 then - | '> 0' - |end - |""".stripMargin) + |if __LINE__ == 0 then + | '= 0' + |elsif __LINE__ > 0 then + | '> 0' + |end + |""".stripMargin) val List(ifNode) = cpg.ifBlock.where(_.lineNumber(2)).l val List(ifCond) = ifNode.condition.isCall.l @@ -173,74 +177,60 @@ class ControlStructureTests extends RubyCode2CpgFixture { "`unless-end` statement is represented by a negated `IF` CONTROL_STRUCTURE node" in { val cpg = code(""" - |unless __LINE__ == 0 then - | x = '!= 0' - |end - |""".stripMargin) - - val List(unlessNode) = cpg.ifBlock.l - val List(unlessNegCond) = unlessNode.condition.isCall.l - val List(assignment) = unlessNode.whenTrue.assignment.l - - unlessNode.whenFalse.isEmpty shouldBe true + |unless __LINE__ == 0 then + | x = '!= 0' + |end + |""".stripMargin) - unlessNegCond.methodFullName shouldBe Operators.logicalNot - unlessNegCond.code shouldBe "__LINE__ == 0" - unlessNegCond.lineNumber shouldBe Some(2) + val List(unlessNode) = cpg.ifBlock.l + val List(unlessCondition) = unlessNode.condition.isCall.l + val List(assignment) = unlessNode.whenFalse.assignment.l - val List(unlessOriginalCond) = unlessNegCond.argument.isCall.l - unlessOriginalCond.methodFullName shouldBe Operators.equals - unlessOriginalCond.code shouldBe "__LINE__ == 0" + unlessCondition.methodFullName shouldBe Operators.equals + unlessCondition.code shouldBe "__LINE__ == 0" + unlessCondition.lineNumber shouldBe Some(2) assignment.code shouldBe "x = '!= 0'" assignment.lineNumber shouldBe Some(3) } - "`unless-else-end` statement is represented by a negated `IF` CONTROL_STRUCTURE node" in { + "`unless-else-end` statement is represented by a `IF` CONTROL_STRUCTURE node" in { val cpg = code(""" - |unless __LINE__ == 0 then - | x = '!= 0' - |else - | x = '= 0' - |end - |""".stripMargin) + |unless __LINE__ == 0 then + | x = '!= 0' + |else + | x = '= 0' + |end + |""".stripMargin) val List(unlessNode) = cpg.ifBlock.l - val List(unlessNegCond) = unlessNode.condition.isCall.l + val List(unlessCond) = unlessNode.condition.isCall.l val List(thenAssignment) = unlessNode.whenTrue.assignment.l val List(elseAssignment) = unlessNode.whenFalse.assignment.l - unlessNegCond.methodFullName shouldBe Operators.logicalNot - unlessNegCond.code shouldBe "__LINE__ == 0" - unlessNegCond.lineNumber shouldBe Some(2) - - val List(unlessOriginalCond) = unlessNegCond.argument.isCall.l - unlessOriginalCond.methodFullName shouldBe Operators.equals - unlessOriginalCond.code shouldBe "__LINE__ == 0" + unlessCond.methodFullName shouldBe Operators.equals + unlessCond.code shouldBe "__LINE__ == 0" + unlessCond.lineNumber shouldBe Some(2) - thenAssignment.code shouldBe "x = '!= 0'" - thenAssignment.lineNumber shouldBe Some(3) + // Then and Else is inverted with UNLESS + thenAssignment.code shouldBe "x = '= 0'" + thenAssignment.lineNumber shouldBe Some(5) - elseAssignment.code shouldBe "x = '= 0'" - elseAssignment.lineNumber shouldBe Some(5) + elseAssignment.code shouldBe "x = '!= 0'" + elseAssignment.lineNumber shouldBe Some(3) } "`... unless ...` statement is represented by a negated `IF` CONTROL_STRUCTURE node" in { val cpg = code(""" - |42 unless false - |""".stripMargin) - - val List(unlessNode) = cpg.ifBlock.l - val List(unlessNegCond) = unlessNode.condition.isCall.l - val List(thenLiteral) = unlessNode.whenTrue.isBlock.astChildren.isLiteral.l + |42 unless false + |""".stripMargin) - unlessNegCond.methodFullName shouldBe Operators.logicalNot - unlessNegCond.code shouldBe "false" - unlessNegCond.lineNumber shouldBe Some(2) + val List(unlessNode) = cpg.ifBlock.l + val List(unlessCond) = unlessNode.condition.isLiteral.l + val List(thenLiteral) = unlessNode.whenFalse.isBlock.astChildren.isLiteral.l - val List(unlessOriginalCond) = unlessNegCond.argument.isLiteral.l - unlessOriginalCond.code shouldBe "false" - unlessOriginalCond.lineNumber shouldBe Some(2) + unlessCond.code shouldBe "false" + unlessCond.lineNumber shouldBe Some(2) thenLiteral.code shouldBe "42" thenLiteral.lineNumber shouldBe Some(2) @@ -248,17 +238,14 @@ class ControlStructureTests extends RubyCode2CpgFixture { "`unless` binds tighter than `=`" in { val cpg = code(""" - |x = 1 unless false - |""".stripMargin) - - val List(unlessNode) = cpg.ifBlock.l - val List(unlessNegCond) = unlessNode.condition.isCall.l - val List(assignment) = unlessNode.whenTrue.assignment.l + |x = 1 unless false + |""".stripMargin) - unlessNode.whenFalse.isEmpty shouldBe true + val List(unlessNode) = cpg.ifBlock.l + val List(unlessCond) = unlessNode.condition.isLiteral.l + val List(assignment) = unlessNode.whenFalse.assignment.l - unlessNegCond.methodFullName shouldBe Operators.logicalNot - unlessNegCond.code shouldBe "false" + unlessCond.code shouldBe "false" assignment.code shouldBe "x = 1" assignment.lineNumber shouldBe Some(2) @@ -266,8 +253,8 @@ class ControlStructureTests extends RubyCode2CpgFixture { "`... if ...` statement is represented by an `IF` CONTROL_STRUCTURE node" in { val cpg = code(""" - |"> 1" if __LINE__ > 1 - |""".stripMargin) + |"> 1" if __LINE__ > 1 + |""".stripMargin) val List(ifNode) = cpg.ifBlock.l val List(ifCond) = ifNode.condition.isCall.l @@ -283,8 +270,8 @@ class ControlStructureTests extends RubyCode2CpgFixture { "`... while ...` statement is represented by a `WHILE` CONTROL_STRUCTURE node" in { val cpg = code(""" - |puts 'hi' while (true) - |""".stripMargin) + |puts 'hi' while (true) + |""".stripMargin) val List(whileNode) = cpg.whileBlock.l val List(whileCond) = whileNode.condition.isLiteral.l @@ -293,7 +280,7 @@ class ControlStructureTests extends RubyCode2CpgFixture { whileCond.code shouldBe "true" whileCond.lineNumber shouldBe Some(2) - putsHi.methodFullName shouldBe s"$kernelPrefix:puts" + putsHi.methodFullName shouldBe s"$kernelPrefix.puts" putsHi.code shouldBe "puts 'hi'" putsHi.lineNumber shouldBe Some(2) } @@ -311,92 +298,97 @@ class ControlStructureTests extends RubyCode2CpgFixture { "`begin ... rescue ... end is represented by a `TRY` CONTROL_STRUCTURE node" in { val cpg = code(""" - |def test1 - | begin - | puts - | 1 - | rescue E1 => e - | puts - | 2 - | rescue E2 - | puts - | 3 - | rescue - | puts - | 4 - | else - | puts - | 5 - | ensure - | puts - | 6 - | end - |end - |""".stripMargin) - - val List(rescueNode) = cpg.method("test1").tryBlock.l - rescueNode.controlStructureType shouldBe ControlStructureTypes.TRY - val List(body, rescueBody1, rescueBody2, rescueBody3, elseBody, ensureBody) = rescueNode.astChildren.l - body.ast.isLiteral.code.l shouldBe List("1") - body.order shouldBe 1 + |def test1 + | begin + | puts + | 1 + | rescue E1 => e + | puts + | 2 + | rescue E2 + | puts + | 3 + | rescue + | puts + | 4 + | else + | puts + | 5 + | ensure + | puts + | 6 + | end + |end + |""".stripMargin) - rescueBody1.ast.isLiteral.code.l shouldBe List("2") - rescueBody1.order shouldBe 2 + inside(cpg.method("test1").controlStructure.l) { + case tryStruct :: rescue1Struct :: rescue2Struct :: rescue3Struct :: elseStruct :: ensureStruct :: Nil => + tryStruct.controlStructureType shouldBe ControlStructureTypes.TRY + val body = tryStruct.astChildren.head + body.ast.isLiteral.code.l shouldBe List("1") - rescueBody2.ast.isLiteral.code.l shouldBe List("3") - rescueBody2.order shouldBe 2 + rescue1Struct.controlStructureType shouldBe ControlStructureTypes.CATCH + rescue1Struct.ast.isLocal.code.l shouldBe List("e") + rescue1Struct.ast.isLiteral.code.l shouldBe List("2") - rescueBody3.ast.isLiteral.code.l shouldBe List("4") - rescueBody3.order shouldBe 2 + rescue2Struct.controlStructureType shouldBe ControlStructureTypes.CATCH + rescue2Struct.ast.isLiteral.code.l shouldBe List("3") - elseBody.ast.isLiteral.code.l shouldBe List("5") - elseBody.order shouldBe 2 + rescue3Struct.controlStructureType shouldBe ControlStructureTypes.CATCH + rescue3Struct.ast.isLiteral.code.l shouldBe List("4") - ensureBody.ast.isLiteral.code.l shouldBe List("6") - ensureBody.order shouldBe 3 + elseStruct.controlStructureType shouldBe ControlStructureTypes.ELSE + elseStruct.ast.isLiteral.code.l shouldBe List("5") + ensureStruct.controlStructureType shouldBe ControlStructureTypes.FINALLY + ensureStruct.ast.isLiteral.code.l shouldBe List("6") + case xs => fail(s"Expected 6 structures, got ${xs.code.mkString(",")}") + } } "`begin ... ensure ... end is represented by a `TRY` CONTROL_STRUCTURE node" in { val cpg = code(""" - |def test2 - | begin - | 1 - | ensure - | 2 - | end - |end - |""".stripMargin) - val List(rescueNode) = cpg.method("test2").tryBlock.l - rescueNode.controlStructureType shouldBe ControlStructureTypes.TRY - val List(body, defaultElseBody, ensureBody) = rescueNode.astChildren.l + |def test2 + | begin + | 1 + | ensure + | 2 + | end + |end + |""".stripMargin) + + inside(cpg.method("test2").controlStructure.l) { + case tryStruct :: defaultElseStruct :: ensureStruct :: Nil => + tryStruct.controlStructureType shouldBe ControlStructureTypes.TRY + val body = tryStruct.astChildren.head + body.ast.isLiteral.code.l shouldBe List("1") - body.ast.isLiteral.code.l shouldBe List("1") - body.order shouldBe 1 + defaultElseStruct.controlStructureType shouldBe ControlStructureTypes.ELSE + defaultElseStruct.ast.isLiteral.code.l shouldBe List("nil") - defaultElseBody.ast.isLiteral.code.l shouldBe List("nil") - ensureBody.order shouldBe 3 + ensureStruct.controlStructureType shouldBe ControlStructureTypes.FINALLY + ensureStruct.ast.isLiteral.code.l shouldBe List("2") - ensureBody.ast.isLiteral.code.l shouldBe List("2") - ensureBody.order shouldBe 3 + case xs => fail(s"Expected two structures, got ${xs.code.mkString(",")}") + } } "`for .. in` control structure" should { val cpg = code(""" - |def foo1 - | x = [1, 2, 3] - | for i in x do - | puts x - i - | end - |end - | - |def foo2 - | x = 3 - | for i in 1..x do - | puts x + i - | end - |end - |""".stripMargin) + |def foo1 + | x = [1, 2, 3] + | for i in x do + | puts x - i + | end + |end + | + |def foo2 + | x = 3 + | for i in 1..x do + | puts x + i + | end + |end + |""".stripMargin) "create a FOR control structure node with body with an array iterable" in { inside(cpg.method("foo1").controlStructure.l) { @@ -404,12 +396,25 @@ class ControlStructureTests extends RubyCode2CpgFixture { forEachNode.controlStructureType shouldBe ControlStructureTypes.FOR inside(forEachNode.astChildren.l) { - case (iteratorNode: Identifier) :: (iterableNode: Identifier) :: (doBody: Block) :: Nil => - iteratorNode.code shouldBe "i" - iterableNode.code shouldBe "x" - // We use .ast as there will be an implicit return node here - doBody.ast.isCall.code.headOption shouldBe Option("puts x - i") - case _ => fail("No node for iterable found in `for-in` statement") + case (idxLocal: Local) :: (iVarLocal: Local) :: (initAssign: Call) :: (cond: Call) :: (update: Call) :: (forBlock: Block) :: Nil => + idxLocal.name shouldBe "_idx_" + idxLocal.typeFullName shouldBe Defines.prefixAsCoreType(Defines.Integer) + + iVarLocal.name shouldBe "i" + + initAssign.code shouldBe "_idx_ = 0" + initAssign.name shouldBe Operators.assignment + initAssign.methodFullName shouldBe Operators.assignment + + cond.code shouldBe "_idx_ < x.length" + cond.name shouldBe Operators.lessThan + cond.methodFullName shouldBe Operators.lessThan + + update.code shouldBe "i = x[_idx_++]" + update.name shouldBe Operators.assignment + update.methodFullName shouldBe Operators.assignment + + case xs => fail(s"Expected 6 children for `forEachNode`, got [${xs.code.mkString(",")}]") } inside(forEachNode.astChildren.isBlock.l) { @@ -429,13 +434,25 @@ class ControlStructureTests extends RubyCode2CpgFixture { forEachNode.controlStructureType shouldBe ControlStructureTypes.FOR inside(forEachNode.astChildren.l) { - case (iteratorNode: Identifier) :: (iterableNode: Call) :: (doBody: Block) :: Nil => - iteratorNode.code shouldBe "i" - iterableNode.code shouldBe "1..x" - iterableNode.name shouldBe Operators.range - // We use .ast as there will be an implicit return node here - doBody.ast.isCall.code.headOption shouldBe Option("puts x + i") - case _ => fail("Invalid `for-in` children nodes") + case (idxLocal: Local) :: (iVarLocal: Local) :: (initAssign: Call) :: (cond: Call) :: (update: Call) :: (forBlock: Block) :: Nil => + idxLocal.name shouldBe "_idx_" + idxLocal.typeFullName shouldBe Defines.prefixAsCoreType(Defines.Integer) + + iVarLocal.name shouldBe "i" + + initAssign.code shouldBe "_idx_ = 0" + initAssign.name shouldBe Operators.assignment + initAssign.methodFullName shouldBe Operators.assignment + + cond.code shouldBe "_idx_ < 1..x.length" + cond.name shouldBe Operators.lessThan + cond.methodFullName shouldBe Operators.lessThan + + update.code shouldBe "i = 1..x[_idx_++]" + update.name shouldBe Operators.assignment + update.methodFullName shouldBe Operators.assignment + + case xs => fail(s"Expected 6 children for `forEachNode`, got [${xs.code.mkString(",")}]") } case _ => fail("No control structure node found for `for-in`.") @@ -445,8 +462,8 @@ class ControlStructureTests extends RubyCode2CpgFixture { "implicit if-elsif-else assignment" should { val cpg = code(""" - | a = if (y > 3) then 123 elsif(y < 6) then 2003 elsif(y < 10) then 982 else 456 end - |""".stripMargin) + | a = if (y > 3) then 123 elsif(y < 6) then 2003 elsif(y < 10) then 982 else 456 end + |""".stripMargin) "Create assignment operators for each branch" in { inside(cpg.call.name(Operators.assignment).l) { @@ -462,8 +479,8 @@ class ControlStructureTests extends RubyCode2CpgFixture { "implicit if assignment" should { val cpg = code(""" - | a = if(x > 4) then 123 end - |""".stripMargin) + | a = if(x > 4) then 123 end + |""".stripMargin) "create assignment operators for if and default else branch" in { inside(cpg.call.name(Operators.assignment).l) { @@ -478,16 +495,16 @@ class ControlStructureTests extends RubyCode2CpgFixture { "if-elsif-else in function with explicit return statements" should { val cpg = code(""" - | def foo(x, y) - | if x < 0 then - | return 0 - | elsif x == 0 then - | return x - | else - | return y - | end - |end - |""".stripMargin) + | def foo(x, y) + | if x < 0 then + | return 0 + | elsif x == 0 then + | return x + | else + | return y + | end + |end + |""".stripMargin) "Generate return nodes without unknown nodes" in { inside(cpg.method.name("foo").methodReturn.toReturn.l) { @@ -518,4 +535,167 @@ class ControlStructureTests extends RubyCode2CpgFixture { } } } + + "Generate continue node for next" in { + val cpg = code(""" + |for i in arr do + | next if i % 2 == 0 + |end + |""".stripMargin) + + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.CONTINUE).l) { + case nextControl :: Nil => + nextControl.code shouldBe "next" + case xs => fail(s"Expected next to be continue, got [${xs.code.mkString(",")}]") + } + } + + "A `raise` call with a string argument should generate a `throw` control structure with explicit `StandardError.new` call" in { + val cpg = code("raise 'Hello, world!'") + inside(cpg.controlStructure.l) { + case (ctrlStruct: ControlStructure) :: Nil => + ctrlStruct.code shouldBe "raise 'Hello, world!'" + ctrlStruct.controlStructureType shouldBe ControlStructureTypes.THROW + + val constructorBlock = ctrlStruct.astChildren.head.asInstanceOf[Block] + constructorBlock.ast.isCall.where(_.name(Operators.alloc)).nonEmpty shouldBe true + + val initialize = constructorBlock.ast.isCall.name(Defines.Initialize).head + initialize.code shouldBe "StandardError.new('Hello, world!')" + val helloWorld = initialize.argument(1).asInstanceOf[Literal] + helloWorld.code shouldBe "'Hello, world!'" + case xs => fail(s"Expected single `throw` call, got [${xs.code.mkString(",")}]") + } + } + + "A `raise` call with an explicit error argument should generate a `throw` control structure" in { + val cpg = code("raise ZeroDivisionError.new 'b should not be 0'") + inside(cpg.controlStructure.l) { + case (ctrlStruct: ControlStructure) :: Nil => + ctrlStruct.code shouldBe "raise ZeroDivisionError.new 'b should not be 0'" + ctrlStruct.controlStructureType shouldBe ControlStructureTypes.THROW + + val constructorBlock = ctrlStruct.astChildren.head.asInstanceOf[Block] + constructorBlock.ast.isCall.where(_.name(Operators.alloc)).nonEmpty shouldBe true + + val initialize = constructorBlock.ast.isCall.name(Defines.Initialize).head + initialize.code shouldBe "ZeroDivisionError.new 'b should not be 0'" + val errMsg = initialize.argument(1).asInstanceOf[Literal] + errMsg.code shouldBe "'b should not be 0'" + case xs => fail(s"Expected single `throw` call, got [${xs.code.mkString(",")}]") + } + } + + "Ternary if" in { + val cpg = code(""" + |class Api::V1::UsersController < ApplicationController + | def index + | respond_with @user.admin ? User.all : @user + | end + |end + |""".stripMargin) + + inside(cpg.method.name("index").l) { + case indexMethod :: Nil => + inside(indexMethod.call.name(Operators.conditional).l) { + case ternary :: Nil => + ternary.code shouldBe "@user.admin ? User.all : @user" + + inside(ternary.argument.l) { + case condition :: (leftOpt: Block) :: (rightOpt: Block) :: Nil => + condition.code shouldBe "( = @user).admin" + condition.ast.isFieldIdentifier.code.l shouldBe List("@user", "admin") + + leftOpt.ast.fieldAccess.code.head shouldBe "User.all" + leftOpt.ast.isFieldIdentifier.code.l shouldBe List("User", "all") + + rightOpt.ast.fieldAccess.code.head shouldBe "self.@user" + rightOpt.ast.isFieldIdentifier.code.head shouldBe "@user" + + case xs => fail(s"Expected two arguments, got ${xs.code.mkString(",")}") + } + case xs => fail(s"Expected one call for ternary, got ${xs.code.mkString(",")}") + } + case xs => fail(s"Expected one method, got ${xs.name.mkString(",")}") + } + } + + "RETURN keyword in logicalAndExpression" in { + val cpg = code(""" + |def foo + | if (a == 1 && return) + | puts a + | end + |end + |""".stripMargin) + + inside(cpg.method.name("foo").controlStructure.l) { + case ifStruct :: Nil => + ifStruct.controlStructureType shouldBe ControlStructureTypes.IF + + val List(_: Call, returnCall: Return) = ifStruct.condition.isBlock.astChildren.isCall.argument.l: @unchecked + returnCall.code shouldBe "return" + + case xs => fail(s"Expected one control strucuture, got [${xs.code.mkString(",")}]") + } + } + + "RETURN keyword in logicalOrExpression" in { + val cpg = code(""" + |def foo + | if (a == 10 || return) + | puts a + | end + |end + |""".stripMargin) + + inside(cpg.method.name("foo").controlStructure.l) { + case orIfStruct :: Nil => + orIfStruct.controlStructureType shouldBe ControlStructureTypes.IF + + val List(_: Call, returnCall: Return) = orIfStruct.condition.isBlock.astChildren.isCall.argument.l: @unchecked + returnCall.code shouldBe "return" + case xs => fail(s"Expected one IF structure, got [${xs.code.mkString(",")}]") + } + } + + "ForEach loops" in { + val cpg = code(""" + |fibNumbers = [0, 1, 1, 2, 3, 5, 8, 13] + |for num in fibNumbers + | puts num + |end + |""".stripMargin) + + inside(cpg.method.isModule.controlStructure.l) { + case forEachNode :: Nil => + forEachNode.controlStructureType shouldBe ControlStructureTypes.FOR + + inside(forEachNode.astChildren.l) { + case (idxLocal: Local) :: (numLocal: Local) :: (initAssign: Call) :: (cond: Call) :: (update: Call) :: (forBlock: Block) :: Nil => + idxLocal.name shouldBe "_idx_" + idxLocal.typeFullName shouldBe Defines.prefixAsCoreType(Defines.Integer) + + numLocal.name shouldBe "num" + + initAssign.code shouldBe "_idx_ = 0" + initAssign.name shouldBe Operators.assignment + initAssign.methodFullName shouldBe Operators.assignment + + cond.code shouldBe "_idx_ < fibNumbers.length" + cond.name shouldBe Operators.lessThan + cond.methodFullName shouldBe Operators.lessThan + + update.code shouldBe "num = fibNumbers[_idx_++]" + update.name shouldBe Operators.assignment + update.methodFullName shouldBe Operators.assignment + + val List(putsCall) = cpg.call.nameExact("puts").l + putsCall.astParent shouldBe forBlock + + case xs => fail(s"Expected 6 children for `forEachNode`, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected one node for `forEach` loop, got [${xs.code.mkString(",")}]") + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DependencyTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DependencyTests.scala index 106efaf63b19..d17c74f17ed0 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DependencyTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DependencyTests.scala @@ -1,8 +1,9 @@ package io.joern.rubysrc2cpg.querying +import io.joern.rubysrc2cpg.passes.Defines.Main +import io.joern.rubysrc2cpg.passes.{DependencyPass, Defines as RubyDefines} import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.joern.x2cpg.Defines -import io.joern.rubysrc2cpg.passes.Defines as RubyDefines import io.shiftleft.codepropertygraph.generated.nodes.{Block, Identifier} import io.shiftleft.semanticcpg.language.* @@ -13,7 +14,7 @@ class DependencyTests extends RubyCode2CpgFixture { val cpg = code(DependencyTests.GEMFILELOCK, "Gemfile.lock") "result in dependency nodes of the set packages" in { - inside(cpg.dependency.nameNot(RubyDefines.Resolver).l) { + inside(cpg.dependency.nameNot(RubyDefines.Resolver).versionNot(DependencyPass.CORE_GEM_VERSION).l) { case aruba :: bcrypt :: betterErrors :: Nil => aruba.name shouldBe "aruba" aruba.version shouldBe "0.14.12" @@ -35,7 +36,7 @@ class DependencyTests extends RubyCode2CpgFixture { val cpg = code(DependencyTests.GEMFILE, "Gemfile") "result in dependency nodes of the set packages" in { - inside(cpg.dependency.nameNot(RubyDefines.Resolver).l) { + inside(cpg.dependency.nameNot(RubyDefines.Resolver).versionNot(DependencyPass.CORE_GEM_VERSION).l) { case aruba :: bcrypt :: coffeeRails :: Nil => aruba.name shouldBe "aruba" aruba.version shouldBe "2.5.1" @@ -58,7 +59,10 @@ class DependencyTests extends RubyCode2CpgFixture { "be preferred over a normal Gemfile" in { // Our Gemfile.lock specifies exact versions whereas the Gemfile does not - cpg.dependency.nameNot(RubyDefines.Resolver).forall(d => !d.version.isBlank) shouldBe true + cpg.dependency + .nameNot(RubyDefines.Resolver) + .versionNot(DependencyPass.CORE_GEM_VERSION) + .forall(d => !d.version.isBlank) shouldBe true } } @@ -94,9 +98,9 @@ class DownloadDependencyTest extends RubyCode2CpgFixture(downloadDependencies = case (v: Identifier) :: (block: Block) :: Nil => v.dynamicTypeHintFullName should contain("dummy_logger.Main_module.Main_outer_class") - inside(block.astChildren.isCall.nameExact("new").headOption) { + inside(block.astChildren.isCall.nameExact(RubyDefines.Initialize).headOption) { case Some(constructorCall) => - constructorCall.methodFullName shouldBe s"dummy_logger.Main_module.Main_outer_class:${RubyDefines.Initialize}" + constructorCall.methodFullName shouldBe Defines.DynamicCallUnknownFullName case None => fail(s"Expected constructor call, did not find one") } case xs => fail(s"Expected two arguments under the constructor assignment, got [${xs.code.mkString(", ")}]") @@ -108,9 +112,9 @@ class DownloadDependencyTest extends RubyCode2CpgFixture(downloadDependencies = case (g: Identifier) :: (block: Block) :: Nil => g.dynamicTypeHintFullName should contain("dummy_logger.Help") - inside(block.astChildren.isCall.name("new").headOption) { + inside(block.astChildren.isCall.name(RubyDefines.Initialize).headOption) { case Some(constructorCall) => - constructorCall.methodFullName shouldBe s"dummy_logger.Help:${RubyDefines.Initialize}" + constructorCall.methodFullName shouldBe Defines.DynamicCallUnknownFullName case None => fail(s"Expected constructor call, did not find one") } case xs => fail(s"Expected two arguments under the constructor assignment, got [${xs.code.mkString(", ")}]") @@ -120,12 +124,12 @@ class DownloadDependencyTest extends RubyCode2CpgFixture(downloadDependencies = // TODO: This requires type propagation "recognise methodFullName for `first_fun`" ignore { cpg.call.name("first_fun").head.methodFullName should equal( - "dummy_logger::program:Main_module:Main_outer_class:first_fun" + s"dummy_logger.$Main.Main_module.Main_outer_class.first_fun" ) cpg.call .name("help_print") .head - .methodFullName shouldBe "dummy_logger::program:Help:help_print" + .methodFullName shouldBe s"dummy_logger.$Main:Help:help_print" } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DestructuredAssignmentsTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DestructuredAssignmentsTests.scala index 6e71c4a963f8..883917f0119e 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DestructuredAssignmentsTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DestructuredAssignmentsTests.scala @@ -2,8 +2,8 @@ package io.joern.rubysrc2cpg.querying import io.joern.rubysrc2cpg.passes.Defines.RubyOperators import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Identifier, Literal} -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import io.shiftleft.semanticcpg.language.* class DestructuredAssignmentsTests extends RubyCode2CpgFixture { @@ -77,8 +77,8 @@ class DestructuredAssignmentsTests extends RubyCode2CpgFixture { |a, b, *c = 1, 2, 3, 4 |""".stripMargin) - inside(cpg.assignment.l) { - case aAssignment :: bAssignment :: cAssignment :: Nil => + inside(cpg.assignment.codeExact("a, b, *c = 1, 2, 3, 4").l) { + case aAssignment :: bAssignment :: cAssignment :: _ :: Nil => aAssignment.code shouldBe "a, b, *c = 1, 2, 3, 4" bAssignment.code shouldBe "a, b, *c = 1, 2, 3, 4" cAssignment.code shouldBe "a, b, *c = 1, 2, 3, 4" @@ -91,10 +91,12 @@ class DestructuredAssignmentsTests extends RubyCode2CpgFixture { b.name shouldBe "b" two.code shouldBe "2" - val List(c: Identifier, arr: Call) = cAssignment.argumentOut.toList: @unchecked + val List(c: Identifier, arr: Block) = cAssignment.argumentOut.toList: @unchecked c.name shouldBe "c" - arr.name shouldBe Operators.arrayInitializer - inside(arr.argumentOut.l) { + arr.code shouldBe "a, b, *c = 1, 2, 3, 4" + + val asgns = arr.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l + inside(asgns.map(_.source)) { case (three: Literal) :: (four: Literal) :: Nil => three.code shouldBe "3" four.code shouldBe "4" @@ -111,8 +113,8 @@ class DestructuredAssignmentsTests extends RubyCode2CpgFixture { |a, *b, c = 1, 2, 3, 4 |""".stripMargin) - inside(cpg.assignment.l) { - case aAssignment :: bAssignment :: cAssignment :: Nil => + inside(cpg.assignment.codeExact("a, *b, c = 1, 2, 3, 4").l) { + case aAssignment :: bAssignment :: _ :: cAssignment :: Nil => aAssignment.code shouldBe "a, *b, c = 1, 2, 3, 4" bAssignment.code shouldBe "a, *b, c = 1, 2, 3, 4" cAssignment.code shouldBe "a, *b, c = 1, 2, 3, 4" @@ -121,14 +123,16 @@ class DestructuredAssignmentsTests extends RubyCode2CpgFixture { a.name shouldBe "a" one.code shouldBe "1" - val List(b: Identifier, arr: Call) = bAssignment.argumentOut.toList: @unchecked + val List(b: Identifier, arr: Block) = bAssignment.argumentOut.toList: @unchecked b.name shouldBe "b" - arr.name shouldBe Operators.arrayInitializer - inside(arr.argumentOut.l) { + arr.code shouldBe "a, *b, c = 1, 2, 3, 4" + + val asgns = arr.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l + inside(asgns.map(_.source)) { case (two: Literal) :: (three: Literal) :: Nil => two.code shouldBe "2" three.code shouldBe "3" - case _ => fail("Unexpected number of array elements in `b`'s assignment") + case _ => fail("Unexpected number of array elements in `c`'s assignment") } val List(c: Identifier, four: Literal) = cAssignment.argumentOut.toList: @unchecked @@ -145,20 +149,22 @@ class DestructuredAssignmentsTests extends RubyCode2CpgFixture { |*a, b, c = 1, 2, 3, 4 |""".stripMargin) - inside(cpg.assignment.l) { - case aAssignment :: bAssignment :: cAssignment :: Nil => + inside(cpg.assignment.codeExact("*a, b, c = 1, 2, 3, 4").l) { + case aAssignment :: _ :: bAssignment :: cAssignment :: Nil => aAssignment.code shouldBe "*a, b, c = 1, 2, 3, 4" bAssignment.code shouldBe "*a, b, c = 1, 2, 3, 4" cAssignment.code shouldBe "*a, b, c = 1, 2, 3, 4" - val List(a: Identifier, arr: Call) = aAssignment.argumentOut.toList: @unchecked + val List(a: Identifier, arr: Block) = aAssignment.argumentOut.toList: @unchecked a.name shouldBe "a" - arr.name shouldBe Operators.arrayInitializer - inside(arr.argumentOut.l) { + arr.code shouldBe "*a, b, c = 1, 2, 3, 4" + + val asgns = arr.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l + inside(asgns.map(_.source)) { case (one: Literal) :: (two: Literal) :: Nil => one.code shouldBe "1" two.code shouldBe "2" - case _ => fail("Unexpected number of array elements in `a`'s assignment") + case _ => fail("Unexpected number of array elements in `c`'s assignment") } val List(b: Identifier, three: Literal) = bAssignment.argumentOut.toList: @unchecked @@ -179,14 +185,15 @@ class DestructuredAssignmentsTests extends RubyCode2CpgFixture { |a = 1, 2, 3, 4 |""".stripMargin) - inside(cpg.assignment.l) { + inside(cpg.assignment.codeExact("a = 1, 2, 3, 4").l) { case aAssignment :: Nil => aAssignment.code shouldBe "a = 1, 2, 3, 4" - val List(a: Identifier, arr: Call) = aAssignment.argumentOut.toList: @unchecked + val List(a: Identifier, arr: Block) = aAssignment.argumentOut.toList: @unchecked a.name shouldBe "a" - arr.name shouldBe Operators.arrayInitializer - inside(arr.argumentOut.l) { + arr.code shouldBe "1, 2, 3, 4" + val asgns = arr.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l + inside(asgns.map(_.source)) { case (one: Literal) :: (two: Literal) :: (three: Literal) :: (four: Literal) :: Nil => one.code shouldBe "1" two.code shouldBe "2" @@ -210,7 +217,7 @@ class DestructuredAssignmentsTests extends RubyCode2CpgFixture { |a, b, c = 1, 2, *list |""".stripMargin) - inside(cpg.assignment.l) { + inside(cpg.assignment.codeExact("a, b, c = 1, 2, *list", "list = [3, 4]").l) { case listAssignment :: aAssignment :: bAssignment :: cAssignment :: Nil => listAssignment.code shouldBe "list = [3, 4]" aAssignment.code shouldBe "a, b, c = 1, 2, *list" @@ -245,7 +252,7 @@ class DestructuredAssignmentsTests extends RubyCode2CpgFixture { |a, b, c = 1, *list |""".stripMargin) - inside(cpg.assignment.l) { + inside(cpg.assignment.codeExact("list = [3, 4]", "a, b, c = 1, *list").l) { case listAssignment :: aAssignment :: bAssignment :: cAssignment :: Nil => listAssignment.code shouldBe "list = [3, 4]" aAssignment.code shouldBe "a, b, c = 1, *list" @@ -281,4 +288,130 @@ class DestructuredAssignmentsTests extends RubyCode2CpgFixture { } + "Destructured Assignment with splat in the middle" in { + val cpg = code(""" + |a, *b, c = 1, 2, 3, 4, 5, 6 + |""".stripMargin) + + inside(cpg.assignment.codeExact("a, *b, c = 1, 2, 3, 4, 5, 6").l) { + case aAssignment :: bAssignment :: _ :: cAssignment :: Nil => + aAssignment.code shouldBe "a, *b, c = 1, 2, 3, 4, 5, 6" + bAssignment.code shouldBe "a, *b, c = 1, 2, 3, 4, 5, 6" + cAssignment.code shouldBe "a, *b, c = 1, 2, 3, 4, 5, 6" + + val List(a: Identifier, lit: Literal) = aAssignment.argumentOut.toList: @unchecked + a.name shouldBe "a" + lit.code shouldBe "1" + + val List(splat: Identifier, arr: Block) = bAssignment.argumentOut.toList: @unchecked + splat.name shouldBe "b" + arr.code shouldBe "a, *b, c = 1, 2, 3, 4, 5, 6" + val asgns = arr.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l + inside(asgns.map(_.source)) { + case (two: Literal) :: (three: Literal) :: (four: Literal) :: (five: Literal) :: Nil => + two.code shouldBe "2" + three.code shouldBe "3" + four.code shouldBe "4" + five.code shouldBe "5" + case _ => fail("Unexpected number of array elements in `*`'s assignment") + } + + val List(c: Identifier, cLiteral: Literal) = cAssignment.argumentOut.toList: @unchecked + c.name shouldBe "c" + cLiteral.code shouldBe "6" + case xs => fail(s"Expected three assignments, got ${xs.code.mkString(",")}") + } + } + + "Destructured assignment with naked splat" in { + val cpg = code(""" + |*, a = 1, 2, 3 + |""".stripMargin) + + inside(cpg.assignment.codeExact("*, a = 1, 2, 3").l) { + case splatAssignment :: _ :: aAssignment :: Nil => + aAssignment.code shouldBe "*, a = 1, 2, 3" + splatAssignment.code shouldBe "*, a = 1, 2, 3" + + val List(a: Identifier, lit: Literal) = aAssignment.argumentOut.toList: @unchecked + a.name shouldBe "a" + lit.code shouldBe "3" + + val List(splat: Identifier, arr: Block) = splatAssignment.argumentOut.toList: @unchecked + splat.name shouldBe "_" + arr.code shouldBe "*, a = 1, 2, 3" + val asgns = arr.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l + inside(asgns.map(_.source)) { + case (one: Literal) :: (two: Literal) :: Nil => + one.code shouldBe "1" + two.code shouldBe "2" + case _ => fail("Unexpected number of array elements in `*`'s assignment") + } + case _ => fail("Unexpected number of assignments found") + } + } + + "Destructured Assignment RHS" in { + val cpg = code(""" + |a, *b, c = 1, 2, *d, *f, 4 + |""".stripMargin) + + inside(cpg.assignment.codeExact("a, *b, c = 1, 2, *d, *f, 4").l) { + case aAssignment :: bAssignment :: _ :: cAssignment :: Nil => + aAssignment.code shouldBe "a, *b, c = 1, 2, *d, *f, 4" + bAssignment.code shouldBe "a, *b, c = 1, 2, *d, *f, 4" + cAssignment.code shouldBe "a, *b, c = 1, 2, *d, *f, 4" + + val List(a: Identifier, aLiteral: Literal) = aAssignment.argumentOut.toList: @unchecked + a.name shouldBe "a" + aLiteral.code shouldBe "1" + + val List(splat: Identifier, arr: Block) = bAssignment.argumentOut.toList: @unchecked + splat.name shouldBe "b" + arr.code shouldBe "a, *b, c = 1, 2, *d, *f, 4" + + val asgns = arr.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l + inside(asgns.map(_.source)) { + case (two: Literal) :: (d: Call) :: (f: Call) :: Nil => + two.code shouldBe "2" + + d.code shouldBe "*d" + d.methodFullName shouldBe RubyOperators.splat + + f.code shouldBe "*f" + f.methodFullName shouldBe RubyOperators.splat + + case xs => fail(s"Unexpected number of array elements in `*`'s assignment, got ${xs.code.mkString(",")}") + } + + val List(c: Identifier, cLiteral: Literal) = cAssignment.argumentOut.toList: @unchecked + c.name shouldBe "c" + cLiteral.code shouldBe "4" + + case xs => fail(s"Expected 3 assignments, got ${xs.code.mkString(",")}") + } + } + + "multi-assignments as a return value" should { + + val cpg = code(""" + |def f + | a, b = 1, 2 # => return [1, 2] + |end + |""".stripMargin) + + "create an explicit return of the LHS values as an array" in { + val arrayLiteral = cpg.method.name("f").methodReturn.toReturn.astChildren.isBlock.head + + arrayLiteral.code shouldBe "a, b = 1, 2" + + val asgns = arrayLiteral.astChildren.assignment.where(_.target.isCall.nameExact(Operators.indexAccess)).l + inside(asgns.map(_.source)) { case (a: Identifier) :: (b: Identifier) :: Nil => + a.code shouldBe "a" + b.code shouldBe "b" + } + } + + } + } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala index 4b89a14b3c87..90f4f3b3f40a 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala @@ -1,10 +1,12 @@ package io.joern.rubysrc2cpg.querying -import io.joern.rubysrc2cpg.passes.GlobalTypes.builtinPrefix +import io.joern.rubysrc2cpg.passes.Defines.{Initialize, Main, RubyOperators} import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.joern.x2cpg.Defines +import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* +import io.joern.rubysrc2cpg.passes.Defines as RubyDefines class DoBlockTests extends RubyCode2CpgFixture { @@ -21,7 +23,7 @@ class DoBlockTests extends RubyCode2CpgFixture { | |""".stripMargin) - "create an anonymous method with associated type declaration" in { + "create an anonymous method with associated type declaration and wrapper type" in { inside(cpg.method.isModule.l) { case program :: Nil => inside(program.astChildren.collectAll[Method].l) { @@ -29,18 +31,22 @@ class DoBlockTests extends RubyCode2CpgFixture { foo.name shouldBe "foo" closureMethod.name shouldBe "0" - closureMethod.fullName shouldBe "Test0.rb:::program:0" + closureMethod.fullName shouldBe s"Test0.rb:$Main.0" case xs => fail(s"Expected a two method nodes, instead got [${xs.code.mkString(", ")}]") } inside(program.astChildren.collectAll[TypeDecl].isLambda.l) { case closureType :: Nil => closureType.name shouldBe "0" - closureType.fullName shouldBe "Test0.rb:::program:0" + closureType.fullName shouldBe s"Test0.rb:$Main.0" + case xs => fail(s"Expected a one closure type node, instead got [${xs.code.mkString(", ")}]") + } + inside(program.astChildren.collectAll[TypeDecl].name(".*Proc").l) { + case closureType :: Nil => val callMember = closureType.member.nameExact("call").head callMember.typeFullName shouldBe Defines.Any - callMember.dynamicTypeHintFullName shouldBe Seq("Test0.rb:::program:0") + callMember.dynamicTypeHintFullName shouldBe Seq(s"Test0.rb:$Main.0") case xs => fail(s"Expected a one closure type node, instead got [${xs.code.mkString(", ")}]") } case xs => fail(s"Expected a single program module, instead got [${xs.code.mkString(", ")}]") @@ -48,9 +54,9 @@ class DoBlockTests extends RubyCode2CpgFixture { } "create a method ref argument with populated type full name, which corresponds to the method type" in { - val methodRefArg = cpg.call("foo").argument(1).head.asInstanceOf[MethodRef] + val typeRefArg = cpg.call("foo").argument(1).head.asInstanceOf[TypeRef] val lambdaTypeDecl = cpg.typeDecl("0").head - methodRefArg.typeFullName shouldBe lambdaTypeDecl.fullName + typeRefArg.typeFullName shouldBe s"${lambdaTypeDecl.fullName}&Proc" } "have no parameters in the closure declaration" in { @@ -79,19 +85,19 @@ class DoBlockTests extends RubyCode2CpgFixture { |""".stripMargin) "create an anonymous method with associated type declaration" in { - inside(cpg.method.nameExact(":program").l) { + inside(cpg.method.isModule.l) { case program :: Nil => inside(program.astChildren.collectAll[Method].l) { case closureMethod :: Nil => closureMethod.name shouldBe "0" - closureMethod.fullName shouldBe "Test0.rb:::program:0" + closureMethod.fullName shouldBe s"Test0.rb:$Main.0" case xs => fail(s"Expected a one method nodes, instead got [${xs.code.mkString(", ")}]") } - inside(program.astChildren.collectAll[TypeDecl].l) { + inside(program.astChildren.collectAll[TypeDecl].isLambda.l) { case closureType :: Nil => closureType.name shouldBe "0" - closureType.fullName shouldBe "Test0.rb:::program:0" + closureType.fullName shouldBe s"Test0.rb:$Main.0" case xs => fail(s"Expected a one closure type node, instead got [${xs.code.mkString(", ")}]") } case xs => fail(s"Expected a single program module, instead got [${xs.code.mkString(", ")}]") @@ -108,13 +114,13 @@ class DoBlockTests extends RubyCode2CpgFixture { "specify the closure reference as an argument to the member call with block" in { inside(cpg.call("each").argument.l) { - case (myArray: Identifier) :: (lambdaRef: MethodRef) :: Nil => + case (myArray: Identifier) :: (lambdaRef: TypeRef) :: Nil => myArray.argumentIndex shouldBe 0 myArray.name shouldBe "my_array" myArray.code shouldBe "my_array" lambdaRef.argumentIndex shouldBe 1 - lambdaRef.methodFullName shouldBe "Test0.rb:::program:0" + lambdaRef.typeFullName shouldBe s"Test0.rb:$Main.0&Proc" case xs => fail(s"Expected `each` call to have call and method ref arguments, instead got [${xs.code.mkString(", ")}]") } @@ -141,20 +147,20 @@ class DoBlockTests extends RubyCode2CpgFixture { |""".stripMargin) "create an anonymous method with associated type declaration" in { - inside(cpg.method.nameExact(":program").l) { + inside(cpg.method.isModule.l) { case program :: Nil => inside(program.astChildren.collectAll[Method].l) { case closureMethod :: Nil => closureMethod.name shouldBe "0" - closureMethod.fullName shouldBe "Test0.rb:::program:0" + closureMethod.fullName shouldBe s"Test0.rb:$Main.0" closureMethod.isLambda.nonEmpty shouldBe true case xs => fail(s"Expected a one method nodes, instead got [${xs.code.mkString(", ")}]") } - inside(program.astChildren.collectAll[TypeDecl].l) { + inside(program.astChildren.collectAll[TypeDecl].isLambda.l) { case closureType :: Nil => closureType.name shouldBe "0" - closureType.fullName shouldBe "Test0.rb:::program:0" + closureType.fullName shouldBe s"Test0.rb:$Main.0" closureType.isLambda.nonEmpty shouldBe true case xs => fail(s"Expected a one closure type node, instead got [${xs.code.mkString(", ")}]") } @@ -173,13 +179,13 @@ class DoBlockTests extends RubyCode2CpgFixture { "specify the closure reference as an argument to the member call with block" in { inside(cpg.call("each").argument.l) { - case (hash: Identifier) :: (lambdaRef: MethodRef) :: Nil => + case (hash: Identifier) :: (lambdaRef: TypeRef) :: Nil => hash.argumentIndex shouldBe 0 hash.name shouldBe "hash" hash.code shouldBe "hash" lambdaRef.argumentIndex shouldBe 1 - lambdaRef.methodFullName shouldBe "Test0.rb:::program:0" + lambdaRef.typeFullName shouldBe s"Test0.rb:$Main.0&Proc" case xs => fail(s"Expected `each` call to have call and method ref arguments, instead got [${xs.code.mkString(", ")}]") } @@ -207,14 +213,16 @@ class DoBlockTests extends RubyCode2CpgFixture { |""".stripMargin) // Basic assertions for expected behaviour - "create the declarations for the closure" in { - inside(cpg.method(".*").l) { + "create the declarations for the closure with captured local" in { + inside(cpg.method.isLambda.l) { case m :: Nil => m.name should startWith("") + val myValue = m.local.nameExact("myValue").head + myValue.closureBindingId shouldBe Option(s"Test0.rb:$Main.myValue") case xs => fail(s"Expected exactly one closure method decl, instead got [${xs.code.mkString(",")}]") } - inside(cpg.typeDecl(".*").l) { + inside(cpg.typeDecl.isLambda.l) { case m :: Nil => m.name should startWith("") case xs => fail(s"Expected exactly one closure type decl, instead got [${xs.code.mkString(",")}]") @@ -224,17 +232,17 @@ class DoBlockTests extends RubyCode2CpgFixture { "annotate the nodes via CAPTURE bindings" in { cpg.all.collectAll[ClosureBinding].l match { case myValue :: Nil => - myValue.closureOriginalName.head shouldBe "myValue" + myValue.closureOriginalName shouldBe Option("myValue") inside(myValue._localViaRefOut) { case Some(local) => local.name shouldBe "myValue" - local.method.fullName.headOption shouldBe Option("Test0.rb:::program") + local.method.fullName.headOption shouldBe Option(s"Test0.rb:$Main") case None => fail("Expected closure binding refer to the captured local") } inside(myValue._captureIn.l) { - case (x: MethodRef) :: Nil => x.methodFullName shouldBe "Test0.rb:::program:0" - case xs => fail(s"Expected single method ref binding but got [${xs.mkString(",")}]") + case (x: TypeRef) :: Nil => x.typeFullName shouldBe s"Test0.rb:$Main.0&Proc" + case xs => fail(s"Expected single method ref binding but got [${xs.mkString(",")}]") } case xs => @@ -260,15 +268,16 @@ class DoBlockTests extends RubyCode2CpgFixture { inside(constrBlock.astChildren.l) { case (tmpLocal: Local) :: (tmpAssign: Call) :: (newCall: Call) :: (_: Identifier) :: Nil => tmpLocal.name shouldBe "" - tmpAssign.code shouldBe " = Array.new(x) { |i| i += 1 }" + tmpAssign.code shouldBe s" = Array.$Initialize" - newCall.name shouldBe "new" - newCall.methodFullName shouldBe s"$builtinPrefix.Array:initialize" + newCall.name shouldBe Initialize + newCall.methodFullName shouldBe Defines.DynamicCallUnknownFullName + newCall.dynamicTypeHintFullName should contain(s"${RubyDefines.prefixAsCoreType(s"Array")}.$Initialize") inside(newCall.argument.l) { - case (_: Identifier) :: (x: Identifier) :: (closure: MethodRef) :: Nil => + case (_: Identifier) :: (x: Identifier) :: (closure: TypeRef) :: Nil => x.name shouldBe "x" - closure.methodFullName should endWith("0") + closure.typeFullName should endWith("0&Proc") case xs => fail(s"Expected a base, `x`, and closure ref, instead got [${xs.code.mkString(",")}]") } case xs => @@ -276,7 +285,8 @@ class DoBlockTests extends RubyCode2CpgFixture { s"Expected four nodes under the lowering block of a constructor, instead got [${xs.code.mkString(",")}]" ) } - case xs => fail(s"Unexpected `foo` assignment children [${xs.code.mkString(",")}]") + case xs => + fail(s"Unexpected `foo` assignment children [${xs.code.mkString(",")}]") } } } @@ -308,13 +318,228 @@ class DoBlockTests extends RubyCode2CpgFixture { "create a call `test_name` with a test name and lambda argument" in { inside(cpg.call.nameExact("test_name").argument.l) { - case (_: Identifier) :: (testName: Literal) :: (testMethod: MethodRef) :: Nil => + case (_: Identifier) :: (testName: Literal) :: (testMethod: TypeRef) :: Nil => testName.code shouldBe "'Foo'" - testMethod.referencedMethod.call.nameExact("puts").nonEmpty shouldBe true + cpg.method + .fullNameExact(testMethod.typ.referencedTypeDecl.member.name("call").dynamicTypeHintFullName.toSeq*) + .call + .nameExact("puts") + .nonEmpty shouldBe true case xs => fail(s"Expected a literal and method ref argument, instead got $xs") } } } + "A lambda with arrow syntax" should { + + val cpg = code(""" + |arrow_lambda = ->(y) { y } + |""".stripMargin) + + "create a lambda method with a `y` parameter" in { + inside(cpg.method.isLambda.headOption) { + case Some(lambda) => + lambda.code shouldBe "->(y) { y }" + lambda.parameter.name.l shouldBe List("self", "y") + case xs => fail(s"Expected a lambda method") + } + } + + "create a method ref assigned to `arrow_lambda`" in { + inside(cpg.method.isModule.assignment.code("arrow_lambda.*").headOption) { + case Some(lambdaAssign) => + lambdaAssign.target.asInstanceOf[Identifier].name shouldBe "arrow_lambda" + lambdaAssign.source.asInstanceOf[TypeRef].typeFullName shouldBe s"Test0.rb:$Main.0&Proc" + case xs => fail(s"Expected an assignment to a lambda") + } + } + + } + + "A lambda with lambda keyword syntax" should { + + val cpg = code(""" + |a_lambda = lambda { |y| y } + |""".stripMargin) + + "create a lambda method with a `y` parameter" in { + inside(cpg.method.isLambda.headOption) { + case Some(lambda) => + lambda.code shouldBe "lambda { |y| y }" + lambda.parameter.name.l shouldBe List("self", "y") + case xs => fail(s"Expected a lambda method") + } + } + + "create a method ref assigned to `arrow_lambda`" in { + inside(cpg.method.isModule.assignment.code("a_lambda.*").headOption) { + case Some(lambdaAssign) => + lambdaAssign.target.asInstanceOf[Identifier].name shouldBe "a_lambda" + lambdaAssign.source.asInstanceOf[TypeRef].typeFullName shouldBe s"Test0.rb:$Main.0&Proc" + case xs => fail(s"Expected an assignment to a lambda") + } + } + + } + + "One local node for variable in lambda only" in { + val cpg = code(""" + | def get_pto_schedule + | begin + | jfs = [] + | schedules = [] + | schedules.each do |s| + | hash = Hash.new + | hash[:id] = s[:id] + | hash[:title] = s[:event_name] + | hash[:start] = s[:date_begin] + | hash[:end] = s[:date_end] + | jfs << hash + | end + | rescue + | end + | end + |""".stripMargin) + + inside(cpg.local.nameNot("").l) { + case jfsOutsideLocal :: schedules :: hashInsideLocal :: jfsCapturedLocal :: Nil => + jfsOutsideLocal.closureBindingId shouldBe None + hashInsideLocal.closureBindingId shouldBe None + jfsCapturedLocal.closureBindingId shouldBe Some("Test0.rb:
.get_pto_schedule.jfs") + case xs => fail(s"Expected 6 locals, got ${xs.code.mkString(",")}") + } + + inside(cpg.method.isLambda.local.l) { + case hashLocal :: _ :: jfsLocal :: Nil => + hashLocal.closureBindingId shouldBe None + jfsLocal.closureBindingId shouldBe Some("Test0.rb:
.get_pto_schedule.jfs") + case xs => fail(s"Expected 3 locals in lambda, got ${xs.code.mkString(",")}") + } + } + + "Various do-block parameters" should { + val cpg = code(""" + |f { |a, (b, c), *d, e, (f, *g), **h, &i| + | puts a + |} + |""".stripMargin) + + "Generate correct parameters" in { + inside(cpg.method.isLambda.parameter.l) { + case _ :: aParam :: tmp0Param :: dParam :: eParam :: tmp1Param :: hParam :: iParam :: Nil => + aParam.name shouldBe "a" + aParam.code shouldBe "a" + + tmp0Param.name shouldBe "" + tmp0Param.code shouldBe "" + + dParam.name shouldBe "d" + dParam.code shouldBe "*d" + + eParam.name shouldBe "e" + eParam.code shouldBe "e" + + tmp1Param.name shouldBe "" + tmp1Param.code shouldBe "" + + hParam.name shouldBe "h" + hParam.code shouldBe "**h" + + iParam.name shouldBe "i" + iParam.code shouldBe "&i" + case xs => fail(s"Expected 8 parameters, got [${xs.name.mkString(", ")}]") + } + } + + "Generate required locals" in { + inside(cpg.method.isLambda.body.local.l) { + case bLocal :: cLocal :: fLocal :: gSplatLocal :: Nil => + bLocal.code shouldBe "b" + cLocal.code shouldBe "c" + + fLocal.code shouldBe "f" + gSplatLocal.code shouldBe "g" + case xs => fail(s"Expected 4 locals, got [${xs.name.mkString(", ")}]") + } + } + + "Generate required `assignment` calls" in { + inside(cpg.method.isLambda.call(Operators.assignment).l) { + case bAssign :: cAssign :: fAssign :: gAssign :: Nil => + bAssign.code shouldBe "b = *" + cAssign.code shouldBe "c = *" + + fAssign.code shouldBe "f = *" + gAssign.code shouldBe "*g = *" + case xs => fail(s"Expected 4 assignments, got [${xs.code.mkString(", ")}]") + } + } + + "Return nil and not the desugaring" in { + val nilLiteral = cpg.method.isLambda.methodReturn.toReturn.astChildren.isLiteral.head + nilLiteral.code shouldBe "return nil" + } + } + + "Nested grouped parameter in block" in { + val cpg = code(""" + |def format_result(result) + | result.each_with_object({}) do |((label_id, date), count), hash| + | label = labels_by_id.fetch(label_id) + | end.values + |end + |""".stripMargin) + + inside(cpg.method.isLambda.body.astChildren.isCall.name(Operators.assignment).l) { + case _ :: groupedParam :: countAssignment :: Nil => + inside(groupedParam.argument.l) { + case (labelIdAssign: Call) :: (dateAssign: Call) :: (tmp0Splat: Call) :: Nil => + inside(labelIdAssign.argument.l) { + case (lhs: Identifier) :: (rhs: Call) :: Nil => + lhs.code shouldBe "label_id" + + rhs.code shouldBe "*" + rhs.methodFullName shouldBe RubyOperators.splat + case xs => + fail(s"Expected lhs and rhs for assignment, got ${xs.code.mkString(",")}") + } + + inside(dateAssign.argument.l) { + case (lhs: Identifier) :: (rhs: Call) :: Nil => + lhs.code shouldBe "date" + + rhs.code shouldBe "*" + rhs.methodFullName shouldBe RubyOperators.splat + case xs => + fail(s"Expected lhs and rhs for assignment, got ${xs.code.mkString(",")}") + } + case xs => fail(s"Expected four arguments for call, got [${xs.code.mkString(",")}]") + } + + inside(countAssignment.argument.l) { + case (lhs: Identifier) :: (rhs: Call) :: Nil => + lhs.code shouldBe "count" + rhs.code shouldBe "*" + rhs.methodFullName shouldBe RubyOperators.splat + case xs => fail(s"Expected LHS and RHS for assignment, got ${xs.code.mkString(",")}") + } + case xs => fail(s"Expected three assignment calls, got [${xs.code.mkString(",")}]") + } + } + + "a back reference in a do block should be a field access from `self`" in { + val cpg = code(""" + |def bar() + | foo("something") { urls << $& } + |end + |""".stripMargin) + val backRefCall = cpg.method.isLambda.ast.fieldAccess + .and(_.fieldIdentifier.canonicalNameExact("$&"), _.argument(1).isIdentifier.nameExact(RubyDefines.Self)) + .head + backRefCall.name shouldBe Operators.fieldAccess + backRefCall.code shouldBe "self.$&" + backRefCall.lineNumber shouldBe Option(3) + backRefCall.columnNumber shouldBe Option(29) + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FieldAccessTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FieldAccessTests.scala index 8f6608977daa..7e330caef591 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FieldAccessTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FieldAccessTests.scala @@ -1,5 +1,7 @@ package io.joern.rubysrc2cpg.querying +import io.joern.rubysrc2cpg.passes.Defines +import io.joern.rubysrc2cpg.passes.Defines.Main import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier, TypeRef} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} @@ -7,41 +9,24 @@ import io.shiftleft.semanticcpg.language.* class FieldAccessTests extends RubyCode2CpgFixture { - "`x.y` is represented by an `x.y` CALL without arguments" in { + "`x.y` is represented by a `x.y` field access" in { val cpg = code(""" - |x.y - |""".stripMargin) + |x = Foo.new + |x.y + |""".stripMargin) - inside(cpg.call("y").headOption) { + inside(cpg.fieldAccess.code("x.y").headOption) { case Some(xyCall) => - xyCall.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH - xyCall.lineNumber shouldBe Some(2) + xyCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + xyCall.name shouldBe Operators.fieldAccess + xyCall.methodFullName shouldBe Operators.fieldAccess + xyCall.lineNumber shouldBe Some(3) xyCall.code shouldBe "x.y" - - inside(xyCall.argumentOption(0)) { - case Some(receiver: Call) => - receiver.name shouldBe Operators.fieldAccess - receiver.code shouldBe "self.x" - case _ => fail("Expected an field access receiver") - } - - inside(xyCall.receiver.headOption) { - case Some(xyBase: Call) => - xyBase.name shouldBe Operators.fieldAccess - xyBase.code shouldBe "x.y" - - val selfX = xyBase.argument(1).asInstanceOf[Call] - selfX.code shouldBe "self.x" - - val yIdentifier = xyBase.argument(2).asInstanceOf[FieldIdentifier] - yIdentifier.code shouldBe "y" - case _ => fail("Expected an field access receiver") - } - case None => fail("Expected a call with the name `y`") + case None => fail("Expected a field access with the code `x.y`") } } - "`self.x` should correctly create a `this` node field base" in { + "`self.x` should correctly create a `self` node field base" in { // Example from railsgoat val cpg = code(""" @@ -55,17 +40,21 @@ class FieldAccessTests extends RubyCode2CpgFixture { |end |""".stripMargin) - inside(cpg.call.name("sick_days_earned").l) { + inside(cpg.fieldAccess.code("self.sick_days_earned").l) { case sickDays :: _ => sickDays.code shouldBe "self.sick_days_earned" - sickDays.name shouldBe "sick_days_earned" - sickDays.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + sickDays.name shouldBe Operators.fieldAccess + sickDays.methodFullName shouldBe Operators.fieldAccess + sickDays.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH inside(sickDays.argument.l) { - case (self: Identifier) :: Nil => + case (self: Identifier) :: (sickDaysId: FieldIdentifier) :: Nil => self.name shouldBe "self" self.code shouldBe "self" self.typeFullName should endWith("PaidTimeOff") + + sickDaysId.canonicalName shouldBe "@sick_days_earned" + sickDaysId.code shouldBe "sick_days_earned" case xs => fail(s"Expected exactly two field access arguments, instead got [${xs.code.mkString(", ")}]") } case Nil => fail("Expected at least one call with `self` base, but got none.") @@ -83,8 +72,8 @@ class FieldAccessTests extends RubyCode2CpgFixture { | end |end | - |Base64::decode64 # self.Base64.decode64() - |Baz::func1 # self.Baz.func1() + |Base64::decode64() # self.Base64.decode64() + |Baz::func1() # self.Baz.func1() | |# self.Foo = TYPE_REF Foo |class Foo @@ -105,7 +94,7 @@ class FieldAccessTests extends RubyCode2CpgFixture { bazAssign.code shouldBe "self.Baz" val bazTypeRef = baz.argument(2).asInstanceOf[TypeRef] - bazTypeRef.typeFullName shouldBe "Test0.rb:::program.Baz" + bazTypeRef.typeFullName shouldBe s"Test0.rb:$Main.Baz" bazTypeRef.code shouldBe "module Baz (...)" val fooAssign = foo.argument(1).asInstanceOf[Call] @@ -113,27 +102,26 @@ class FieldAccessTests extends RubyCode2CpgFixture { fooAssign.code shouldBe "self.Foo" val fooTypeRef = foo.argument(2).asInstanceOf[TypeRef] - fooTypeRef.typeFullName shouldBe "Test0.rb:::program.Foo" + fooTypeRef.typeFullName shouldBe s"Test0.rb:$Main.Foo" fooTypeRef.code shouldBe "class Foo (...)" case _ => fail(s"Expected two type ref assignments on the module level") } } "give external type accesses on script-level the `self.` base" in { - val call = cpg.method.isModule.call.codeExact("Base64::decode64").head + val call = cpg.method.isModule.call.nameExact("decode64").head call.name shouldBe "decode64" - val base = call.argument(0).asInstanceOf[Call] - base.name shouldBe Operators.fieldAccess - base.code shouldBe "self.Base64" + val base = call.argument(0).asInstanceOf[Identifier] + base.code shouldBe "" val receiver = call.receiver.isCall.head receiver.name shouldBe Operators.fieldAccess - receiver.code shouldBe "Base64.decode64" + receiver.code shouldBe "( = Base64).decode64" val selfArg1 = receiver.argument(1).asInstanceOf[Call] - selfArg1.name shouldBe Operators.fieldAccess - selfArg1.code shouldBe "self.Base64" + selfArg1.name shouldBe Operators.assignment + selfArg1.code shouldBe " = Base64" val selfArg2 = receiver.argument(2).asInstanceOf[FieldIdentifier] selfArg2.canonicalName shouldBe "decode64" @@ -141,20 +129,19 @@ class FieldAccessTests extends RubyCode2CpgFixture { } "give internal type accesses on script-level the `self.` base" in { - val call = cpg.method.isModule.call.codeExact("Baz::func1").head + val call = cpg.method.isModule.call.nameExact("func1").head call.name shouldBe "func1" - val base = call.argument(0).asInstanceOf[Call] - base.name shouldBe Operators.fieldAccess - base.code shouldBe "self.Baz" + val base = call.argument(0).asInstanceOf[Identifier] + base.code shouldBe "" val receiver = call.receiver.isCall.head receiver.name shouldBe Operators.fieldAccess - receiver.code shouldBe "Baz.func1" + receiver.code shouldBe "( = Baz).func1" val selfArg1 = receiver.argument(1).asInstanceOf[Call] - selfArg1.name shouldBe Operators.fieldAccess - selfArg1.code shouldBe "self.Baz" + selfArg1.name shouldBe Operators.assignment + selfArg1.code shouldBe " = Baz" val selfArg2 = receiver.argument(2).asInstanceOf[FieldIdentifier] selfArg2.canonicalName shouldBe "func1" @@ -186,17 +173,16 @@ class FieldAccessTests extends RubyCode2CpgFixture { val call = cpg.method.nameExact("func").call.nameExact("func1").head call.name shouldBe "func1" - val base = call.argument(0).asInstanceOf[Call] - base.name shouldBe Operators.fieldAccess - base.code shouldBe "self.Baz" + val base = call.argument(0).asInstanceOf[Identifier] + base.code shouldBe "" val receiver = call.receiver.isCall.head receiver.name shouldBe Operators.fieldAccess - receiver.code shouldBe "Baz.func1" + receiver.code shouldBe "( = Baz).func1" val selfArg1 = receiver.argument(1).asInstanceOf[Call] - selfArg1.name shouldBe Operators.fieldAccess - selfArg1.code shouldBe "self.Baz" + selfArg1.name shouldBe Operators.assignment + selfArg1.code shouldBe " = Baz" val selfArg2 = receiver.argument(2).asInstanceOf[FieldIdentifier] selfArg2.canonicalName shouldBe "func1" @@ -213,7 +199,7 @@ class FieldAccessTests extends RubyCode2CpgFixture { | end | module C | # TYPE_REF A B func - | A::B::func + | A::B::func() | end | end |end @@ -222,23 +208,24 @@ class FieldAccessTests extends RubyCode2CpgFixture { "create `TYPE_REF` targets for the field accesses" in { val call = cpg.call.nameExact("func").head - val base = call.argument(0).asInstanceOf[Call] - base.name shouldBe Operators.fieldAccess - base.code shouldBe "A::B" - - base.argument(1).asInstanceOf[TypeRef].typeFullName shouldBe "Test0.rb:::program.A" - base.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "B" + val base = call.argument(0).asInstanceOf[Identifier] + base.code shouldBe "" val receiver = call.receiver.isCall.head receiver.name shouldBe Operators.fieldAccess - receiver.code shouldBe "A::B.func" + receiver.code shouldBe "( = A::B).func" val selfArg1 = receiver.argument(1).asInstanceOf[Call] - selfArg1.name shouldBe Operators.fieldAccess - selfArg1.code shouldBe "A::B" + selfArg1.name shouldBe Operators.assignment + selfArg1.code shouldBe " = A::B" + + selfArg1.argument(1).asInstanceOf[Identifier].code shouldBe s"" + + val abRhs = selfArg1.argument(2).asInstanceOf[Call] + abRhs.code shouldBe "A::B" - selfArg1.argument(1).asInstanceOf[TypeRef].typeFullName shouldBe "Test0.rb:::program.A" - selfArg1.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "B" + abRhs.argument(1).asInstanceOf[TypeRef].typeFullName shouldBe s"Test0.rb:$Main.A" + abRhs.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "B" val selfArg2 = receiver.argument(2).asInstanceOf[FieldIdentifier] selfArg2.canonicalName shouldBe "func" diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/HashTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/HashTests.scala index 2cd5e4caa398..682edeee184b 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/HashTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/HashTests.scala @@ -1,7 +1,7 @@ package io.joern.rubysrc2cpg.querying import io.joern.rubysrc2cpg.passes.Defines.RubyOperators -import io.joern.rubysrc2cpg.passes.GlobalTypes.{builtinPrefix, kernelPrefix} +import io.joern.rubysrc2cpg.passes.Defines import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal, TypeRef} @@ -30,7 +30,7 @@ class HashTests extends RubyCode2CpgFixture { val List(assocCall) = hashCall.inCall.astSiblings.assignment.l val List(x, one) = assocCall.argument.l - x.code shouldBe "[x]" + x.code shouldBe "[:x]" one.code shouldBe "1" } @@ -64,13 +64,13 @@ class HashTests extends RubyCode2CpgFixture { regexp.code shouldBe "/(eu|us)/" } - "`{:x : /(eu|us)/}` is represented by a `hashInitializer` operator call" in { + "`{x: /(eu|us)/}` is represented by a `hashInitializer` operator call" in { val cpg = code(""" - |{:x : /(eu|us)/} + |{x: /(eu|us)/} |""".stripMargin) val List(hashCall) = cpg.call.name(RubyOperators.hashInitializer).l - hashCall.code shouldBe "{:x : /(eu|us)/}" + hashCall.code shouldBe "{x: /(eu|us)/}" hashCall.lineNumber shouldBe Some(2) val List(assocCall) = hashCall.inCall.astSiblings.assignment.l @@ -81,8 +81,8 @@ class HashTests extends RubyCode2CpgFixture { "Inclusive Range of primitive ordinal type should expand in hash key" in { val cpg = code(""" - |{1..3:"abc", 4..5:"ade"} - |{'a'..'c': "abc"} + |{1..3 => "abc", 4..5 => "ade"} + |{'a'..'c' => "abc"} |""".stripMargin) inside(cpg.call.name(RubyOperators.hashInitializer).l) { @@ -101,7 +101,7 @@ class HashTests extends RubyCode2CpgFixture { lhs.name shouldBe Operators.indexAccess rhs.code shouldBe "\"abc\"" - rhs.typeFullName shouldBe s"$kernelPrefix.String" + rhs.typeFullName shouldBe Defines.prefixAsCoreType("String") case _ => fail("Expected LHS and RHS after lowering") } @@ -110,7 +110,7 @@ class HashTests extends RubyCode2CpgFixture { lhs.name shouldBe Operators.indexAccess rhs.code shouldBe "\"ade\"" - rhs.typeFullName shouldBe s"$kernelPrefix.String" + rhs.typeFullName shouldBe Defines.prefixAsCoreType("String") } case _ => fail("Expected 5 calls (one per item in range)") } @@ -127,19 +127,19 @@ class HashTests extends RubyCode2CpgFixture { lhs.name shouldBe Operators.indexAccess rhs.code shouldBe "\"abc\"" - rhs.typeFullName shouldBe s"$kernelPrefix.String" + rhs.typeFullName shouldBe Defines.prefixAsCoreType("String") case _ => fail("Expected LHS and RHS after lowering") } case _ => fail("Expected 3 calls (one per item in range)") } - case _ => fail("Expected one hash initializer function") + case _ => fail("Expected two hash initializer functions") } } "Exclusive Range of primitive ordinal type should expand in hash key" in { val cpg = code(""" - |{1...3:"abc"} + |{1...3 => "abc"} |""".stripMargin) inside(cpg.call.name(RubyOperators.hashInitializer).l) { @@ -157,7 +157,7 @@ class HashTests extends RubyCode2CpgFixture { "Non-Primitive ordinal type should not expand in hash key" in { val cpg = code(""" - |{:a...:b:"a"} + |{:a...:b => "a"} |""".stripMargin) inside(cpg.call.name(RubyOperators.hashInitializer).l) { @@ -175,7 +175,7 @@ class HashTests extends RubyCode2CpgFixture { case _ => fail("Expected range operator for non-primitive range key") } - rhs.typeFullName shouldBe s"$kernelPrefix.String" + rhs.typeFullName shouldBe Defines.prefixAsCoreType("String") rhs.code shouldBe "\"a\"" case _ => fail("Expected LHS and RHS for association") } @@ -195,8 +195,8 @@ class HashTests extends RubyCode2CpgFixture { case hashCall :: Nil => hashCall.code shouldBe "Hash [1 => \"a\", 2 => \"b\", 3 => \"c\"]" hashCall.lineNumber shouldBe Some(2) - hashCall.methodFullName shouldBe s"$builtinPrefix.Hash.[]" - hashCall.typeFullName shouldBe s"$builtinPrefix.Hash" + hashCall.methodFullName shouldBe s"${Defines.prefixAsCoreType("Hash")}.[]" + hashCall.typeFullName shouldBe Defines.prefixAsCoreType("Hash") inside(hashCall.astChildren.l) { case (_: Call) :: (_: TypeRef) :: (one: Call) :: (two: Call) :: (three: Call) :: Nil => @@ -209,4 +209,63 @@ class HashTests extends RubyCode2CpgFixture { } } + "Splatting argument in hash" in { + val cpg = code(""" + |a = {**x, **y} + |""".stripMargin) + + inside(cpg.call.name(RubyOperators.hashInitializer).l) { + case hashCall :: Nil => + val List(xSplatCall, ySplatCall) = hashCall.inCall.astSiblings.isCall.l + xSplatCall.code shouldBe "**x" + xSplatCall.methodFullName shouldBe RubyOperators.splat + + ySplatCall.code shouldBe "**y" + ySplatCall.methodFullName shouldBe RubyOperators.splat + case xs => fail(s"Expected call to hashInitializer, [${xs.code.mkString(",")}]") + } + } + + "Function call in hash" in { + val cpg = code(""" + |bar = 200 + |a = {**foo(bar)} + |""".stripMargin) + + inside(cpg.call.name(RubyOperators.hashInitializer).l) { + case hashInitializer :: Nil => + val List(splatCall) = hashInitializer.inCall.astSiblings.isCall.l + splatCall.code shouldBe "**foo(bar)" + splatCall.name shouldBe RubyOperators.splat + + val List(splatCallArg: Call) = splatCall.argument.l: @unchecked + + splatCallArg.code shouldBe "foo(bar)" + + val List(_, barCallArg) = splatCallArg.argument.l + barCallArg.code shouldBe "bar" + case xs => fail(s"Expected one call for init, got [${xs.code.mkString(",")}]") + } + } + + "Function call without parentheses" in { + val cpg = code(""" + |a = {**(foo 13)} + |""".stripMargin) + + inside(cpg.call.name(RubyOperators.hashInitializer).l) { + case hashInitializer :: Nil => + val List(splatCall) = hashInitializer.inCall.astSiblings.isCall.l + splatCall.code shouldBe "**(foo 13)" + splatCall.name shouldBe RubyOperators.splat + + val List(splatCallArg: Call) = splatCall.argument.l: @unchecked + + splatCallArg.code shouldBe "foo 13" + + val List(selfCallArg, literalCallArg) = splatCallArg.argument.l + literalCallArg.code shouldBe "13" + case xs => fail(s"Expected one call for init, got [${xs.code.mkString(",")}]") + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/HereDocTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/HereDocTests.scala index 748335043856..747b78324dd2 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/HereDocTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/HereDocTests.scala @@ -1,8 +1,8 @@ package io.joern.rubysrc2cpg.querying +import io.joern.rubysrc2cpg.passes.Defines import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.joern.rubysrc2cpg.passes.GlobalTypes.kernelPrefix -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Literal, Local, Method, Return} +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class HereDocTests extends RubyCode2CpgFixture { @@ -26,7 +26,7 @@ class HereDocTests extends RubyCode2CpgFixture { localAst.code shouldBe "a" callAst.code shouldBe "a = 10" - literalAst.typeFullName shouldBe s"$kernelPrefix.String" + literalAst.typeFullName shouldBe Defines.prefixAsCoreType("String") returnAst.code shouldBe "a" case _ => @@ -55,7 +55,7 @@ class HereDocTests extends RubyCode2CpgFixture { inside(assignmentCall.argument.l) { case lhsArg :: (rhsArg: Literal) :: Nil => lhsArg.code shouldBe "a" - rhsArg.typeFullName shouldBe s"$kernelPrefix.String" + rhsArg.typeFullName shouldBe Defines.prefixAsCoreType("String") case _ => fail("Expected LHS and RHS for assignment") } case _ => fail("Expected call for assignment") @@ -65,7 +65,7 @@ class HereDocTests extends RubyCode2CpgFixture { } } - "HereDoc as a function argument" ignore { + "HereDoc as a function argument" should { val cpg = code(""" |def foo(arg) | bar(arg, <<-SOME_HEREDOC, arg + 1) @@ -74,7 +74,11 @@ class HereDocTests extends RubyCode2CpgFixture { |end |""".stripMargin) - // TODO: This creates a syntax error + "create a string literal in the 2nd argument position" in { + val barCall = cpg.call("bar").head + val hereDoc = barCall.argument(2).asInstanceOf[Literal] + hereDoc.code shouldBe " inside here doc\n" + } } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ImportTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ImportTests.scala index 5551d82ec6b1..980f9d4abc9c 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ImportTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ImportTests.scala @@ -1,10 +1,12 @@ package io.joern.rubysrc2cpg.querying +import io.joern.rubysrc2cpg.passes.Defines +import io.joern.rubysrc2cpg.passes.Defines.{Initialize, Main} import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.joern.x2cpg.frontendspecific.rubysrc2cpg.{ImplicitRequirePass, ImportsPass, TypeImportInfo} +import io.shiftleft.codepropertygraph.generated.DispatchTypes +import io.shiftleft.codepropertygraph.generated.nodes.Literal import io.shiftleft.semanticcpg.language.* -import io.joern.rubysrc2cpg.RubySrc2Cpg -import io.joern.rubysrc2cpg.Config -import scala.util.{Success, Failure} import org.scalatest.Inspectors class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with Inspectors { @@ -21,6 +23,42 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In call.argument.where(_.argumentIndexGt(0)).code.l shouldBe List("'test'") } + "`require_relative 'test'` is a CALL node with an IMPORT node pointing to it" in { + val cpg = code(""" + |require_relative 'test' + |""".stripMargin) + val List(importNode) = cpg.imports.l + importNode.importedEntity shouldBe Some("test") + importNode.importedAs shouldBe Some("test") + val List(call) = importNode.call.l + call.callee.name.l shouldBe List("require_relative") + call.argument.where(_.argumentIndexGt(0)).code.l shouldBe List("'test'") + } + + "`load 'test'` is a CALL node with an IMPORT node pointing to it" in { + val cpg = code(""" + |load 'test' + |""".stripMargin) + val List(importNode) = cpg.imports.l + importNode.importedEntity shouldBe Some("test") + importNode.importedAs shouldBe Some("test") + val List(call) = importNode.call.l + call.callee.name.l shouldBe List("load") + call.argument.where(_.argumentIndexGt(0)).code.l shouldBe List("'test'") + } + + "`require_all 'test'` is a CALL node with an IMPORT node pointing to it" in { + val cpg = code(""" + |require_all 'test' + |""".stripMargin) + val List(importNode) = cpg.imports.l + importNode.importedEntity shouldBe Some("test") + importNode.importedAs shouldBe Some("test") + val List(call) = importNode.call.l + call.callee.name.l shouldBe List("require_all") + call.argument.where(_.argumentIndexGt(0)).code.l shouldBe List("'test'") + } + "`begin require 'test' rescue LoadError end` has a CALL node with an IMPORT node pointing to it" in { val cpg = code(""" |begin @@ -59,8 +97,15 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In ) val List(newCall) = - cpg.method.name(":program").filename("t1.rb").ast.isCall.methodFullName(".*:initialize").methodFullName.l - newCall should startWith(s"${path}.rb:") + cpg.method.isModule + .filename("t1.rb") + .ast + .isCall + .dynamicTypeHintFullName + .filter(x => x.startsWith(path) && x.endsWith(Initialize)) + .l + + newCall should startWith(s"$path.rb:") } } @@ -86,11 +131,183 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In |""".stripMargin) val List(methodName) = - cpg.method.name("bar").ast.isCall.methodFullName(".*::program\\.(A|B):foo").methodFullName.l + cpg.method.name("bar").ast.isCall.methodFullName(s".*\\.$Main\\.(A|B).foo").methodFullName.l methodName should endWith(s"${moduleName}:foo") } } + "implicitly imported types in base class" should { + val cpg = code( + """ + |class MyController < ApplicationController + |end + |""".stripMargin, + "app/controllers/my_controller.rb" + ) + .moreCode( + """ + |class ApplicationController + |end + |""".stripMargin, + "app/controllers/application_controller.rb" + ) + .moreCode( + """ + |GEM + | remote: https://rubygems.org/ + | specs: + | zeitwerk (2.2.1) + |""".stripMargin, + "Gemfile.lock" + ) + + "result in require statement of the file containing the symbol" in { + inside(cpg.imports.where(_.call.file.name(".*my_controller.rb")).toList) { case List(i) => + i.importedAs shouldBe Some("app/controllers/application_controller") + i.importedEntity shouldBe Some("app/controllers/application_controller") + } + } + } + + "implicitly imported types in base class that are qualified names" should { + val cpg = code( + """ + |class MyController < Controllers::ApplicationController + |end + |""".stripMargin, + "app/controllers/my_controller.rb" + ) + .moreCode( + """ + |module Controllers + | class ApplicationController + | end + |end + |""".stripMargin, + "app/controllers/controllers.rb" + ) + .moreCode( + """ + |GEM + | remote: https://rubygems.org/ + | specs: + | zeitwerk (2.2.1) + |""".stripMargin, + "Gemfile.lock" + ) + + "result in require statement of the file containing the symbol" in { + inside(cpg.imports.where(_.call.file.name(".*my_controller.rb")).toList) { case List(i) => + i.importedAs shouldBe Some("app/controllers/controllers") + i.importedEntity shouldBe Some("app/controllers/controllers") + } + } + } + + "implicitly imported types that are qualified names in an include statement" should { + val cpg = code( + """ + |module MyController + | include Controllers::ApplicationController + |end + |""".stripMargin, + "app/controllers/my_controller.rb" + ) + .moreCode( + """ + |module Controllers + | class ApplicationController + | end + |end + |""".stripMargin, + "app/controllers/controllers.rb" + ) + .moreCode( + """ + |GEM + | remote: https://rubygems.org/ + | specs: + | zeitwerk (2.2.1) + |""".stripMargin, + "Gemfile.lock" + ) + + "result in require statement of the file containing the symbol" in { + inside(cpg.imports.where(_.call.file.name(".*my_controller.rb")).toList) { case List(i) => + i.importedAs shouldBe Some("app/controllers/controllers") + i.importedEntity shouldBe Some("app/controllers/controllers") + } + } + } + + "implicitly imported types in include statement" should { + val cpg = code( + """ + |class MyController + | include ApplicationController + |end + |""".stripMargin, + "app/controllers/my_controller.rb" + ) + .moreCode( + """ + |class ApplicationController + |end + |""".stripMargin, + "app/controllers/application_controller.rb" + ) + .moreCode( + """ + |GEM + | remote: https://rubygems.org/ + | specs: + | zeitwerk (2.2.1) + |""".stripMargin, + "Gemfile.lock" + ) + + "result in require statement of the file containing the symbol" in { + inside(cpg.imports.where(_.call.file.name(".*my_controller.rb")).toList) { case List(i) => + i.importedAs shouldBe Some("app/controllers/application_controller") + i.importedEntity shouldBe Some("app/controllers/application_controller") + } + } + } + + "implicitly imported types in extend statement" should { + val cpg = code( + """ + |class MyController + | extend ApplicationController + |end + |""".stripMargin, + "app/controllers/my_controller.rb" + ) + .moreCode( + """ + |class ApplicationController + |end + |""".stripMargin, + "app/controllers/application_controller.rb" + ) + .moreCode( + """ + |GEM + | remote: https://rubygems.org/ + | specs: + | zeitwerk (2.2.1) + |""".stripMargin, + "Gemfile.lock" + ) + + "result in require statement of the file containing the symbol" in { + inside(cpg.imports.where(_.call.file.name(".*my_controller.rb")).toList) { case List(i) => + i.importedAs shouldBe Some("app/controllers/application_controller") + i.importedEntity shouldBe Some("app/controllers/application_controller") + } + } + } + "implicitly imported types (common in frameworks like Ruby on Rails)" should { val cpg = code( @@ -107,7 +324,7 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In | end |end | - |B::bar + |B::bar() |""".stripMargin, "bar/B.rb" ) @@ -119,7 +336,9 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In ) .moreCode( """ - |B.bar + |def func() + | B.bar() + |end |""".stripMargin, "Bar.rb" ) @@ -157,6 +376,15 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In cpg.imports.where(_.call.file.name(".*B.rb")).size shouldBe 0 } + "create a `require` call following the simplified format" in { + val require = cpg.call("require").head + require.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + require.methodFullName shouldBe Defines.prefixAsKernelDefined("require") + + val strLit = require.argument(1).asInstanceOf[Literal] + strLit.typeFullName shouldBe Defines.prefixAsCoreType("String") + } + } "Builtin Types type-map" should { @@ -170,20 +398,23 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In "resolve calls to builtin functions" in { inside(cpg.call.methodFullName("(pp|csv).*").l) { - case csvParseCall :: csvTableInitCall :: ppCall :: Nil => - csvParseCall.methodFullName shouldBe "csv.CSV:parse" - ppCall.methodFullName shouldBe "pp.PP:pp" - csvTableInitCall.methodFullName shouldBe "csv.CSV.Table:initialize" - case xs => fail(s"Expected three calls, got [${xs.code.mkString(",")}] instead") + case csvParseCall :: csvTableCall :: ppCall :: Nil => + csvParseCall.methodFullName shouldBe "csv.CSV.parse" + csvTableCall.methodFullName shouldBe "csv.CSV.Table.initialize" + ppCall.methodFullName shouldBe "pp.PP.pp" + case xs => fail(s"Expected calls, got [${xs.code.mkString(",")}] instead") } + + // TODO: fixme - set is empty +// cpg.call(Initialize).dynamicTypeHintFullName.toSet should contain("csv.CSV.Table.initialize") } } "`require_all` on a directory" should { val cpg = code(""" |require_all './dir' - |Module1.foo - |Module2.foo + |Module1.foo() + |Module2.foo() |""".stripMargin) .moreCode( """ @@ -206,8 +437,8 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In "allow the resolution for all modules in that directory" in { cpg.call("foo").methodFullName.l shouldBe List( - "dir/module1.rb:::program.Module1:foo", - "dir/module2.rb:::program.Module2:foo" + s"dir/module1.rb:$Main.Module1.foo", + s"dir/module2.rb:$Main.Module2.foo" ) } } @@ -220,7 +451,7 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In |""".stripMargin) "also create import nodes" in { - inside(cpg.imports.l) { + inside(cpg.imports.l.sortBy { impNode => impNode.isCallForImportIn.head.order }) { case requireAll :: requireRelative :: load :: Nil => requireAll.importedAs shouldBe Option("./dir") requireAll.isWildcard shouldBe Option(true) @@ -245,9 +476,9 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In |require 'file2' |require 'file3' | - |File1::foo # lib/file1.rb::program:foo - |File2::foo # lib/file2.rb::program:foo - |File3::foo # src/file3.rb::program:foo + |File1::foo # lib/file1.rb.
.foo + |File2::foo # lib/file2.rb.
.foo + |File3::foo # src/file3.rb.
.foo |""".stripMargin, "main.rb" ).moreCode( @@ -279,11 +510,40 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In "resolve the calls directly" in { inside(cpg.call.name("foo.*").l) { case foo1 :: foo2 :: foo3 :: Nil => - foo1.methodFullName shouldBe "lib/file1.rb:::program.File1:foo" - foo2.methodFullName shouldBe "lib/file2.rb:::program.File2:foo" - foo3.methodFullName shouldBe "src/file3.rb:::program.File3:foo" + foo1.methodFullName shouldBe s"lib/file1.rb:$Main.File1.foo" + foo2.methodFullName shouldBe s"lib/file2.rb:$Main.File2.foo" + foo3.methodFullName shouldBe s"src/file3.rb:$Main.File3.foo" case xs => fail(s"Expected 3 calls, got [${xs.code.mkString(",")}] instead") } } } } + +class ImportWithAutoloadedExternalGemsTests extends RubyCode2CpgFixture(withPostProcessing = false) { + + "use of a type specified as external" should { + + val cpg = code( + """ + |x = Base64.encode("Hello, world!") + |Bar::Foo.new + |""".stripMargin, + "encoder.rb" + ) + + ImplicitRequirePass(cpg, TypeImportInfo("Base64", "base64") :: TypeImportInfo("Bar", "foobar") :: Nil) + .createAndApply() + ImportsPass(cpg).createAndApply() + + "result in require statement of the file containing the symbol" in { + inside(cpg.imports.where(_.call.file.name(".*encoder.rb")).toList) { case List(i1, i2) => + i1.importedAs shouldBe Some("base64") + i1.importedEntity shouldBe Some("base64") + + i2.importedAs shouldBe Some("foobar") + i2.importedEntity shouldBe Some("foobar") + } + } + } + +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IndexAccessTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IndexAccessTests.scala index 3a420aa3265e..8ea753448535 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IndexAccessTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IndexAccessTests.scala @@ -47,4 +47,27 @@ class IndexAccessTests extends RubyCode2CpgFixture { two.code shouldBe "2" } + "Index Access with `.[](index)`" in { + val cpg = code(""" + |class Foo + | def extract_url + | @params.dig(:event, :links)&.first&.[](:url) + | end + |end + |""".stripMargin) + + inside(cpg.call.name(Operators.indexAccess).l) { + case indexCall :: Nil => + indexCall.code shouldBe "@params.dig(:event, :links)&.first&.[](:url)" + + inside(indexCall.argument.l) { + case target :: index :: Nil => + target.code shouldBe "( = @params.dig(:event, :links))&.first" + index.code shouldBe ":url" + case xs => fail(s"Expected target and index, got [${xs.code.mkString(",")}]") + } + + case xs => fail(s"Expected one index access, got [${xs.code.mkString(",")}]") + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/LiteralTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/LiteralTests.scala index 9b579319aaa4..424c81eb6233 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/LiteralTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/LiteralTests.scala @@ -1,8 +1,9 @@ package io.joern.rubysrc2cpg.querying +import io.joern.rubysrc2cpg.passes.Defines as RubyDefines import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators -import io.joern.rubysrc2cpg.passes.GlobalTypes.kernelPrefix +import io.shiftleft.codepropertygraph.generated.nodes.Literal import io.shiftleft.semanticcpg.language.* class LiteralTests extends RubyCode2CpgFixture { @@ -15,7 +16,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "123" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.Integer" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("Integer") } "`3.14` is represented by a LITERAL node" in { @@ -26,7 +27,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "3.14" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.Float" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("Float") } "`3e10` is represented by a LITERAL node" in { @@ -37,7 +38,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "3e10" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.Float" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("Float") } "`12e-10` is represented by a LITERAL node" in { @@ -48,7 +49,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "12e-10" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.Float" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("Float") } "`0b01` is represented by a LITERAL node" in { @@ -59,7 +60,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "0b01" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.Integer" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("Integer") } "`0xabc` is represented by a LITERAL node" in { @@ -70,7 +71,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "0xabc" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.Integer" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("Integer") } "`true` is represented by a LITERAL node" in { @@ -81,7 +82,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "true" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.TrueClass" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("TrueClass") } "`false` is represented by a LITERAL node" in { @@ -92,7 +93,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "false" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.FalseClass" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("FalseClass") } "`nil` is represented by a LITERAL node" in { @@ -103,7 +104,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "nil" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.NilClass" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("NilClass") } "`'hello'` is represented by a LITERAL node" in { @@ -114,18 +115,26 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "'hello'" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.String" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("String") } - "`'x' 'y' 'z'` is represented by a LITERAL node" in { + "`'x' 'y' 'z'` is represented by a dynamic literal node call" in { val cpg = code(""" |'x' 'y' 'z' |""".stripMargin) - val List(literal) = cpg.literal.l - literal.code shouldBe "'x' 'y' 'z'" - literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.String" + val List(dynamicLitCall) = cpg.call.methodFullNameExact(Operators.formatString).l: @unchecked + dynamicLitCall.code shouldBe "'x' 'y' 'z'" + dynamicLitCall.methodFullName shouldBe Operators.formatString + + inside(dynamicLitCall.argument.astChildren.l) { case (x: Literal) :: (y: Literal) :: (z: Literal) :: Nil => + x.code shouldBe "'x'" + x.typeFullName shouldBe RubyDefines.prefixAsCoreType("String") + y.code shouldBe "'y'" + y.typeFullName shouldBe RubyDefines.prefixAsCoreType("String") + z.code shouldBe "'z'" + z.typeFullName shouldBe RubyDefines.prefixAsCoreType("String") + } } "`\"hello\"` is represented by a LITERAL node" in { @@ -136,7 +145,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "\"hello\"" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.String" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("String") } "`%q(hello)` is represented by a LITERAL node" in { @@ -147,7 +156,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "%q(hello)" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.String" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("String") } "`%Q(hello world)` is represented by a LITERAL node" in { @@ -158,7 +167,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "%Q(hello world)" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.String" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("String") } "`%(foo \"bar\" baz)` is represented by a LITERAL node" in { @@ -169,7 +178,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "%(foo \"bar\" baz)" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.String" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("String") } """`%q<\n...\n>` is represented by a LITERAL node""" in { @@ -180,15 +189,15 @@ class LiteralTests extends RubyCode2CpgFixture { |> |""".stripMargin) - val List(literal) = cpg.literal.l - literal.code shouldBe - """%q< - |xyz - |123 - |>""".stripMargin - literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.String" - + val List(firstLine, xyz, one23) = cpg.literal.l + firstLine.code.trim shouldBe "" + firstLine.lineNumber shouldBe Some(2) + xyz.code.trim shouldBe "xyz" + xyz.lineNumber shouldBe Some(3) + xyz.typeFullName shouldBe RubyDefines.prefixAsCoreType("String") + one23.code.trim shouldBe "123" + one23.lineNumber shouldBe Some(4) + one23.typeFullName shouldBe RubyDefines.prefixAsCoreType("String") } "`:symbol` is represented by a LITERAL node" in { @@ -199,7 +208,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe ":symbol" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.Symbol" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("Symbol") } "`:'symbol'` is represented by a LITERAL node" in { @@ -210,7 +219,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe ":'symbol'" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.Symbol" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("Symbol") } "`/(eu|us)/` is represented by a LITERAL node" in { @@ -221,7 +230,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "/(eu|us)/" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.Regexp" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("Regexp") } "`/fedora|el-|centos/` is represented by a LITERAL node" in { @@ -232,7 +241,7 @@ class LiteralTests extends RubyCode2CpgFixture { val List(literal) = cpg.literal.l literal.code shouldBe "/fedora|el-|centos/" literal.lineNumber shouldBe Some(2) - literal.typeFullName shouldBe s"$kernelPrefix.Regexp" + literal.typeFullName shouldBe RubyDefines.prefixAsCoreType("Regexp") } "`/#{os_version_regex}/` is represented by a CALL node with a string format method full name" in { @@ -244,8 +253,26 @@ class LiteralTests extends RubyCode2CpgFixture { val List(formatValueCall) = cpg.call.code("/#.*").l formatValueCall.code shouldBe "/#{os_version_regex}/" formatValueCall.lineNumber shouldBe Some(3) - formatValueCall.typeFullName shouldBe s"$kernelPrefix.Regexp" + formatValueCall.typeFullName shouldBe RubyDefines.prefixAsCoreType("Regexp") formatValueCall.methodFullName shouldBe Operators.formatString } + "-> Lambda literal" in { + val cpg = code(""" + |-> (a, *b, &c) {} + |""".stripMargin) + + inside(cpg.method.isLambda.l) { + case lambdaLiteral :: Nil => + inside(lambdaLiteral.parameter.l) { + case _ :: aParam :: bParam :: cParam :: Nil => + aParam.code shouldBe "a" + bParam.code shouldBe "*b" + cParam.code shouldBe "&c" + case xs => fail(s"Expected four parameters, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected one lambda, got [${xs.name.mkString(",")}]") + } + } + } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodReturnTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodReturnTests.scala index d2cc7fd8775a..82b2e44f0837 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodReturnTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodReturnTests.scala @@ -1,13 +1,13 @@ package io.joern.rubysrc2cpg.querying -import io.joern.rubysrc2cpg.passes.Defines.RubyOperators +import io.joern.rubysrc2cpg.passes.Defines as RubyDefines +import io.joern.rubysrc2cpg.passes.Defines.{Main, RubyOperators} import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.Operators -import io.joern.rubysrc2cpg.passes.GlobalTypes.kernelPrefix -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Literal, Method, MethodRef, Return} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Operators} import io.shiftleft.semanticcpg.language.* -class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { +class MethodReturnTests extends RubyCode2CpgFixture { "implicit RETURN node for `x * x` exists" in { val cpg = code(""" @@ -72,7 +72,7 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { r.lineNumber shouldBe Some(3) val List(c: Call) = r.astChildren.isCall.l - c.methodFullName shouldBe s"$kernelPrefix:puts" + c.methodFullName shouldBe RubyDefines.prefixAsKernelDefined("puts") c.lineNumber shouldBe Some(3) c.code shouldBe "puts x" } @@ -103,8 +103,8 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { r.code shouldBe "[]" r.lineNumber shouldBe Some(3) - val List(c: Call) = r.astChildren.isCall.l - c.methodFullName shouldBe Operators.arrayInitializer + val List(arr: Block) = r.astChildren.isBlock.l + arr.code shouldBe "[]" } "implicit RETURN node for index access exists" in { @@ -156,7 +156,7 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { val List(s: Literal) = r.astChildren.isLiteral.l s.code shouldBe ":g" - s.typeFullName shouldBe s"$kernelPrefix.Symbol" + s.typeFullName shouldBe RubyDefines.prefixAsCoreType("Symbol") } "explicit RETURN node for `\"\"` exists" in { @@ -192,14 +192,14 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { val List(twenty: Literal) = return20.astChildren.l: @unchecked twenty.code shouldBe "20" twenty.lineNumber shouldBe Some(4) - twenty.typeFullName shouldBe s"$kernelPrefix.Integer" + twenty.typeFullName shouldBe RubyDefines.prefixAsCoreType("Integer") returnNil.code shouldBe "return nil" returnNil.lineNumber shouldBe Some(3) val List(nil: Literal) = returnNil.astChildren.l: @unchecked nil.code shouldBe "nil" nil.lineNumber shouldBe Some(3) - nil.typeFullName shouldBe s"$kernelPrefix.NilClass" + nil.typeFullName shouldBe RubyDefines.prefixAsCoreType("NilClass") case xs => fail(s"Expected exactly two return nodes, instead got [${xs.code.mkString(",")}]") } case xs => fail(s"Expected exactly one method with the name `f`, instead got [${xs.code.mkString(",")}]") @@ -227,14 +227,14 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { val List(twenty: Literal) = return20.astChildren.l: @unchecked twenty.code shouldBe "20" twenty.lineNumber shouldBe Some(4) - twenty.typeFullName shouldBe s"$kernelPrefix.Integer" + twenty.typeFullName shouldBe RubyDefines.prefixAsCoreType("Integer") return40.code shouldBe "40" return40.lineNumber shouldBe Some(6) val List(forty: Literal) = return40.astChildren.l: @unchecked forty.code shouldBe "40" forty.lineNumber shouldBe Some(6) - forty.typeFullName shouldBe s"$kernelPrefix.Integer" + forty.typeFullName shouldBe RubyDefines.prefixAsCoreType("Integer") case xs => fail(s"Expected exactly two return nodes, instead got [${xs.code.mkString(",")}]") } case xs => fail(s"Expected exactly one method with the name `f`, instead got [${xs.code.mkString(",")}]") @@ -297,14 +297,14 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { val List(twenty: Literal) = return20.astChildren.l: @unchecked twenty.code shouldBe "20" twenty.lineNumber shouldBe Some(2) - twenty.typeFullName shouldBe s"$kernelPrefix.Integer" + twenty.typeFullName shouldBe RubyDefines.prefixAsCoreType("Integer") return40.code shouldBe "40" return40.lineNumber shouldBe Some(2) val List(forty: Literal) = return40.astChildren.l: @unchecked forty.code shouldBe "40" forty.lineNumber shouldBe Some(2) - forty.typeFullName shouldBe s"$kernelPrefix.Integer" + forty.typeFullName shouldBe RubyDefines.prefixAsCoreType("Integer") case xs => fail(s"Expected exactly two return nodes, instead got [${xs.code.mkString(",")}]") } case xs => fail(s"Expected exactly one method with the name `f`, instead got [${xs.code.mkString(",")}]") @@ -339,7 +339,7 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { } } - "implicit RETURN node for ASSOCIATION" in { + "implicit RETURN node for super call" in { val cpg = code(""" |def j | super(only: ["a"]) @@ -350,18 +350,18 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { case jMethod :: Nil => inside(jMethod.methodReturn.toReturn.l) { case retAssoc :: Nil => - retAssoc.code shouldBe "only: [\"a\"]" + retAssoc.code shouldBe "super(only: [\"a\"])" val List(call: Call) = retAssoc.astChildren.l: @unchecked - call.name shouldBe RubyOperators.association - call.code shouldBe "only: [\"a\"]" + call.name shouldBe "super" + call.code shouldBe "super(only: [\"a\"])" case xs => fail(s"Expected exactly one return nodes, instead got [${xs.code.mkString(",")}]") } case _ => fail("Only one method expected") } } - "implict RETURN node for RubyCallWithBlock" should { + "implicit RETURN node for RubyCallWithBlock" should { val cpg = code(""" | def foo &block | puts block.call @@ -380,7 +380,7 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { inside(bar.astChildren.collectAll[Method].l) { case closureMethod :: Nil => closureMethod.name shouldBe "0" - closureMethod.fullName shouldBe "Test0.rb:::program:bar:0" + closureMethod.fullName shouldBe s"Test0.rb:$Main.bar.0" case xs => fail(s"Expected closure method, but found ${xs.code.mkString(", ")} instead") } @@ -388,12 +388,12 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { case barReturn :: Nil => inside(barReturn.astChildren.l) { case (returnCall: Call) :: Nil => - returnCall.code.replaceAll("\n", "") shouldBe "foo do \"hello\" end" + returnCall.code should startWith("foo") returnCall.name shouldBe "foo" - val List(_, arg: MethodRef) = returnCall.argument.l: @unchecked - arg.methodFullName shouldBe "Test0.rb:::program:bar:0" + val List(_, arg: TypeRef) = returnCall.argument.l: @unchecked + arg.typeFullName shouldBe s"Test0.rb:$Main.bar.0&Proc" case xs => fail(s"Expected one call for return, but found ${xs.code.mkString(", ")} instead") } @@ -421,22 +421,99 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { "implicit return of a heredoc should return a literal" in { val cpg = code(""" - |def custom_fact_content(key='custom_fact', value='custom_value', *args) + |def foo | <<-EOM - | Facter.add('#{key}') do - | setcode {'#{value}'} - | #{args.empty? ? '' : args.join('\n')} - | end - | EOM + | puts "hello" + | EOM |end |""".stripMargin) - inside(cpg.method.nameExact("custom_fact_content").methodReturn.toReturn.astChildren.l) { + inside(cpg.method.nameExact("foo").methodReturn.toReturn.astChildren.l) { case (heredoc: Literal) :: Nil => - heredoc.typeFullName shouldBe s"$kernelPrefix.String" - heredoc.code should startWith("<<-EOM") + heredoc.typeFullName shouldBe RubyDefines.prefixAsCoreType("String") + heredoc.code should startWith(" puts \"hello\"") case xs => fail(s"Expected a single literal node, instead got [${xs.code.mkString(", ")}]") } } + "a return in an expression position without arguments should generate a return node with no children" in { + val cpg = code(""" + |def foo + | return unless baz() + | bar() + |end + |""".stripMargin) + + inside(cpg.method.nameExact("foo").ast.isReturn.l) { + case ret1 :: ret2 :: ret3 :: Nil => + ret1.code shouldBe "return nil" + ret1.astChildren.size shouldBe 1 + ret1.astParent.astParent.code shouldBe "return unless baz()" + + ret2.code shouldBe "return" + ret2.astChildren.size shouldBe 0 + ret2.astParent.astParent.code shouldBe "return unless baz()" + + ret3.code shouldBe "bar()" + case xs => fail(s"Expected 3 return nodes, got ${xs.size}") + } + } + + "a return with multiple values" in { + val cpg = code(""" + |def foo + | return 1, :z => 1 + |end + |""".stripMargin) + + inside(cpg.method.nameExact("foo").ast.isReturn.headOption) { + case Some(ret) => + val List(oneLiteral: Literal, zAssoc: Call) = ret.astChildren.l: @unchecked + oneLiteral.code shouldBe "1" + zAssoc.code shouldBe ":z => 1" + zAssoc.methodFullName shouldBe RubyOperators.association + + inside(zAssoc.argument.l) { + case (key: Literal) :: (value: Literal) :: Nil => + key.code shouldBe ":z" + value.code shouldBe "1" + case xs => fail(s"Expected two args, got ${xs.code.mkString(",")}") + } + case None => fail(s"Expected at least one return node") + } + } + + "Return with method invocation without parentheses" in { + val cpg = code(""" + |def foo() + | return render json: {}, status: :internal_server_error unless success + |end + |""".stripMargin) + + inside(cpg.method.name("foo").body.astChildren.isControlStructure.l) { + case ifNode :: Nil => + ifNode.controlStructureType shouldBe ControlStructureTypes.IF + + val List(successCall: Call) = ifNode.condition.l: @unchecked + successCall.code shouldBe "self.success" + + val List(ifReturnFalse: Return) = ifNode.whenFalse.isBlock.astChildren.isReturn.l + ifReturnFalse.code shouldBe "return render json: {}, status: :internal_server_error" + + val List(_, jsonArg: Block, statusArg: Literal) = ifReturnFalse.astChildren.isCall.argument.l: @unchecked + jsonArg.argumentName shouldBe Some("json") + jsonArg.code shouldBe "" + + val List(_: Identifier, hashInitCall: Call) = jsonArg.astChildren.isCall.argument.l: @unchecked + hashInitCall.methodFullName shouldBe RubyOperators.hashInitializer + + statusArg.argumentName shouldBe Some("status") + statusArg.code shouldBe ":internal_server_error" + + val List(ifReturnTrue: Return) = ifNode.whenTrue.isBlock.astChildren.isReturn.l + ifReturnTrue.code shouldBe "return nil" + + case xs => fail(s"Expected two method returns, got [${xs.code.mkString(",")}]") + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala index 4e6b8bc2daa7..0f6723ca9e37 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala @@ -1,23 +1,25 @@ package io.joern.rubysrc2cpg.querying import io.joern.rubysrc2cpg.passes.Defines as RDefines -import io.joern.rubysrc2cpg.passes.GlobalTypes.kernelPrefix +import io.joern.rubysrc2cpg.passes.Defines.{Main, RubyOperators} import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, NodeTypes, Operators} import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess} class MethodTests extends RubyCode2CpgFixture { "`def f(x) = 1`" should { val cpg = code(""" |def f(x) = 1 + |f(1) |""".stripMargin) "be represented by a METHOD node" in { val List(f) = cpg.method.name("f").l - f.fullName shouldBe "Test0.rb:::program:f" + f.fullName shouldBe s"Test0.rb:$Main.f" f.isExternal shouldBe false f.lineNumber shouldBe Some(2) f.numberOfLines shouldBe 1 @@ -26,21 +28,33 @@ class MethodTests extends RubyCode2CpgFixture { x.index shouldBe 1 x.isVariadic shouldBe false x.lineNumber shouldBe Some(2) + + val List(fSelf) = f.parameter.name(RDefines.Self).l + fSelf.index shouldBe 0 + fSelf.isVariadic shouldBe false + fSelf.lineNumber shouldBe Some(2) + fSelf.referencingIdentifiers.size shouldBe 0 + + val List(mSelf) = cpg.method.isModule.parameter.name(RDefines.Self).l + mSelf.index shouldBe 0 + mSelf.isVariadic shouldBe false + mSelf.lineNumber shouldBe Some(2) + mSelf.referencingIdentifiers.size shouldBe 3 } "have a corresponding bound type" in { val List(fType) = cpg.typeDecl("f").l - fType.fullName shouldBe "Test0.rb:::program:f" + fType.fullName shouldBe s"Test0.rb:$Main.f" fType.code shouldBe "def f(x) = 1" - fType.astParentFullName shouldBe "Test0.rb:::program" + fType.astParentFullName shouldBe s"Test0.rb:$Main" fType.astParentType shouldBe NodeTypes.METHOD val List(fMethod) = fType.iterator.boundMethod.l - fType.fullName shouldBe "Test0.rb:::program:f" + fType.fullName shouldBe s"Test0.rb:$Main.f" } "create a 'fake' method for the file" in { - val List(m) = cpg.method.nameExact(RDefines.Program).l - m.fullName shouldBe "Test0.rb:::program" + val List(m) = cpg.method.nameExact(RDefines.Main).l + m.fullName shouldBe s"Test0.rb:$Main" m.isModule.nonEmpty shouldBe true } } @@ -54,7 +68,7 @@ class MethodTests extends RubyCode2CpgFixture { val List(f) = cpg.method.name("f").l - f.fullName shouldBe "Test0.rb:::program:f" + f.fullName shouldBe s"Test0.rb:$Main.f" f.isExternal shouldBe false f.lineNumber shouldBe Some(2) f.numberOfLines shouldBe 3 @@ -68,7 +82,7 @@ class MethodTests extends RubyCode2CpgFixture { val List(f) = cpg.method.name("f").l - f.fullName shouldBe "Test0.rb:::program:f" + f.fullName shouldBe s"Test0.rb:$Main.f" f.isExternal shouldBe false f.lineNumber shouldBe Some(2) f.numberOfLines shouldBe 1 @@ -86,7 +100,7 @@ class MethodTests extends RubyCode2CpgFixture { val List(f) = cpg.method.name("f").l - f.fullName shouldBe "Test0.rb:::program:f" + f.fullName shouldBe s"Test0.rb:$Main.f" f.isExternal shouldBe false f.lineNumber shouldBe Some(2) f.numberOfLines shouldBe 1 @@ -193,7 +207,7 @@ class MethodTests extends RubyCode2CpgFixture { inside(funcF.parameter.l) { case thisParam :: xParam :: Nil => thisParam.code shouldBe RDefines.Self - thisParam.typeFullName shouldBe "Test0.rb:::program.C" + thisParam.typeFullName shouldBe s"Test0.rb:$Main.C" thisParam.index shouldBe 0 thisParam.isVariadic shouldBe false @@ -227,7 +241,7 @@ class MethodTests extends RubyCode2CpgFixture { inside(funcF.parameter.l) { case thisParam :: xParam :: Nil => thisParam.code shouldBe RDefines.Self - thisParam.typeFullName shouldBe "Test0.rb:::program.C" + thisParam.typeFullName shouldBe s"Test0.rb:$Main.C" thisParam.index shouldBe 0 thisParam.isVariadic shouldBe false @@ -258,7 +272,7 @@ class MethodTests extends RubyCode2CpgFixture { xs.name shouldBe "xs" xs.code shouldBe "*xs" xs.isVariadic shouldBe true - xs.typeFullName shouldBe s"$kernelPrefix.Array" + xs.typeFullName shouldBe RDefines.prefixAsCoreType("Array") case xs => fail(s"Expected `foo` to have one parameter, got [${xs.code.mkString(", ")}]") } } @@ -269,7 +283,7 @@ class MethodTests extends RubyCode2CpgFixture { ys.name shouldBe "ys" ys.code shouldBe "**ys" ys.isVariadic shouldBe true - ys.typeFullName shouldBe s"$kernelPrefix.Hash" + ys.typeFullName shouldBe RDefines.prefixAsCoreType("Hash") case xs => fail(s"Expected `foo` to have one parameter, got [${xs.code.mkString(", ")}]") } } @@ -308,18 +322,23 @@ class MethodTests extends RubyCode2CpgFixture { x.name shouldBe "x" bar.name shouldBe "bar=" - xeq.parameter.name.l shouldBe bar.parameter.name.l + bar.parameter.name.l shouldBe List("self", "args", "&block") // bar forwards parameters to a call to the aliased method inside(bar.call.name("x=").l) { case barCall :: Nil => inside(barCall.argument.l) { - case _ :: (z: Identifier) :: Nil => - z.name shouldBe "z" - z.argumentIndex shouldBe 1 + case _ :: (args: Call) :: (blockId: Identifier) :: Nil => + args.name shouldBe RubyOperators.splat + args.code shouldBe "*args" + args.argumentIndex shouldBe 1 + + blockId.name shouldBe "&block" + blockId.code shouldBe "&block" + blockId.argumentIndex shouldBe 2 case xs => fail(s"Expected a two arguments for the call `x=`, instead got [${xs.code.mkString(",")}]") } - barCall.code shouldBe "x=(z)" + barCall.code shouldBe "x=(*args, &block)" case xs => fail(s"Expected a single call to `bar=`, instead got [${xs.code.mkString(",")}]") } case xs => fail(s"Expected a three virtual methods under `Foo`, instead got [${xs.code.mkString(",")}]") @@ -329,6 +348,60 @@ class MethodTests extends RubyCode2CpgFixture { } } + "aliased methods with `alias_method`" should { + val cpg = code(""" + |class Foo + | def aliasable(bbb) + | puts bbb + | end + | + | alias_method :print_something, :aliasable + | + | def someMethod(aaa) + | print_something(aaa) + | end + |end + | + |""".stripMargin) + + "similarly alias the method as if it were calling `alias`" in { + inside(cpg.typeDecl("Foo").l) { + case foo :: Nil => + inside(foo.method.nameNot(RDefines.Initialize, RDefines.TypeDeclBody).l) { + case a :: p :: s :: Nil => + a.name shouldBe "aliasable" + p.name shouldBe "print_something" + s.name shouldBe "someMethod" + + p.parameter.name.l shouldBe List("self", "args", "&block") + // bar forwards parameters to a call to the aliased method + inside(p.call.name("aliasable").l) { + case aliasableCall :: Nil => + inside(aliasableCall.argument.l) { + case _ :: (args: Call) :: (blockId: Identifier) :: Nil => + args.name shouldBe RubyOperators.splat + args.code shouldBe "*args" + args.argumentIndex shouldBe 1 + + blockId.name shouldBe "&block" + blockId.code shouldBe "&block" + blockId.argumentIndex shouldBe 2 + case xs => + fail( + s"Expected a two arguments for the call `aliasable`, instead got [${xs.code.mkString(",")}]" + ) + } + aliasableCall.code shouldBe "aliasable(*args, &block)" + case xs => fail(s"Expected a single call to `aliasable`, instead got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected a three virtual methods under `Foo`, instead got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected a single type decl for `Foo`, instead got [${xs.code.mkString(",")}]") + } + } + + } + "Singleton Methods for module scope" should { val cpg = code(""" |module F @@ -351,7 +424,7 @@ class MethodTests extends RubyCode2CpgFixture { case thisParam :: xParam :: Nil => thisParam.name shouldBe RDefines.Self thisParam.code shouldBe "F" - thisParam.typeFullName shouldBe "Test0.rb:::program.F" + thisParam.typeFullName shouldBe s"Test0.rb:$Main.F" xParam.name shouldBe "x" case xs => fail(s"Expected two parameters, got ${xs.name.mkString(", ")}") @@ -361,7 +434,7 @@ class MethodTests extends RubyCode2CpgFixture { case thisParam :: xParam :: Nil => thisParam.name shouldBe RDefines.Self thisParam.code shouldBe "F" - thisParam.typeFullName shouldBe "Test0.rb:::program.F" + thisParam.typeFullName shouldBe s"Test0.rb:$Main.F" xParam.name shouldBe "x" xParam.code shouldBe "x" @@ -374,21 +447,35 @@ class MethodTests extends RubyCode2CpgFixture { // TODO: we cannot bind baz as this is a dynamic assignment to `F` which is trickier to determine // Also, double check bindings "have bindings to the singleton module TYPE_DECL" ignore { - cpg.typeDecl.name("F").methodBinding.methodFullName.l shouldBe List("Test0.rb:::program.F:bar") + cpg.typeDecl.name("F").methodBinding.methodFullName.l shouldBe List(s"Test0.rb:$Main.F.bar") } - "baz should not exist in the :program block" in { - inside(cpg.method.name(":program").l) { + "baz should not exist in the
block" in { + inside(cpg.method.isModule.l) { case prog :: Nil => inside(prog.block.astChildren.isMethod.name("baz").l) { case Nil => // passing case case _ => fail("Baz should not exist under program method block") } - case _ => fail("Expected one Method for :program") + case _ => fail("Expected one Method for ") } } } + "Singleton methods binding to an unresolvable variable/type should bind to the next AST parent" in { + val cpg = code(""" + |class C + | def something.foo + | end + |end + |""".stripMargin) + + val foo = cpg.method.nameExact("foo").head + + foo.definingTypeDecl.map(_.name) shouldBe Option("C") + foo.astParent shouldBe cpg.typeDecl("C").head + } + "A Boolean method" should { val cpg = code(""" |def exists? @@ -399,7 +486,7 @@ class MethodTests extends RubyCode2CpgFixture { "be represented by a METHOD node" in { inside(cpg.method.name("exists\\?").l) { case existsMethod :: Nil => - existsMethod.fullName shouldBe "Test0.rb:::program:exists?" + existsMethod.fullName shouldBe s"Test0.rb:$Main.exists?" existsMethod.isExternal shouldBe false inside(existsMethod.methodReturn.cfgIn.l) { @@ -458,7 +545,7 @@ class MethodTests extends RubyCode2CpgFixture { inside(loopMethod.block.astChildren.isControlStructure.l) { case ifStruct :: Nil => inside(ifStruct.astChildren.isBlock.l) { - case breakBlock :: nilBlock :: Nil => + case nilBlock :: breakBlock :: Nil => inside(breakBlock.astChildren.isControlStructure.l) { case breakStruct :: Nil => breakStruct.code shouldBe "break" @@ -509,20 +596,19 @@ class MethodTests extends RubyCode2CpgFixture { |""".stripMargin) "Should be represented as a TRY structure" in { - inside(cpg.method.name("foo").tryBlock.l) { - case tryBlock :: Nil => - tryBlock.controlStructureType shouldBe ControlStructureTypes.TRY - - inside(tryBlock.astChildren.l) { - case body :: ensureBody :: Nil => - body.ast.isLiteral.code.l shouldBe List("1") - body.order shouldBe 1 - - ensureBody.ast.isLiteral.code.l shouldBe List("2") - ensureBody.order shouldBe 3 - case xs => fail(s"Expected body and ensureBody, got ${xs.code.mkString(", ")} instead") - } - case xs => fail(s"Expected one method, found ${xs.method.name.mkString(", ")} instead") + inside(cpg.method.name("foo").controlStructure.l) { + case tryStruct :: emptyElseStruct :: ensureStruct :: Nil => + tryStruct.controlStructureType shouldBe ControlStructureTypes.TRY + val body = tryStruct.astChildren.head + body.ast.isLiteral.code.l shouldBe List("1") + + emptyElseStruct.controlStructureType shouldBe ControlStructureTypes.ELSE + emptyElseStruct.ast.isLiteral.code.l shouldBe List("nil") + + ensureStruct.controlStructureType shouldBe ControlStructureTypes.FINALLY + ensureStruct.ast.isLiteral.code.l shouldBe List("2") + + case xs => fail(s"Expected three structures, got ${xs.code.mkString(",")}") } } } @@ -548,23 +634,23 @@ class MethodTests extends RubyCode2CpgFixture { leftArg.name shouldBe "a" rightArg.name shouldBe "hexdigest" - rightArg.code shouldBe "Digest::MD5.hexdigest(password)" + rightArg.code shouldBe "( = Digest::MD5).hexdigest(password)" - inside(rightArg.argument.l) { - case (md5: Call) :: (passwordArg: Identifier) :: Nil => - md5.name shouldBe Operators.fieldAccess - md5.code shouldBe "Digest::MD5" + val hexDigestFa = rightArg.receiver.head.asInstanceOf[FieldAccess] + hexDigestFa.code shouldBe "( = Digest::MD5).hexdigest" - val md5Base = md5.argument(1).asInstanceOf[Call] - md5.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "MD5" + val tmp1Assign = hexDigestFa.argument(1).asInstanceOf[Assignment] + tmp1Assign.code shouldBe " = Digest::MD5" - md5Base.name shouldBe Operators.fieldAccess - md5Base.code shouldBe "self.Digest" + val md5Fa = tmp1Assign.source.asInstanceOf[FieldAccess] + md5Fa.code shouldBe "( = Digest)::MD5" - md5Base.argument(1).asInstanceOf[Identifier].name shouldBe RDefines.Self - md5Base.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "Digest" - case xs => fail(s"Expected identifier and call, got ${xs.code.mkString(", ")} instead") - } + val tmp0Assign = md5Fa.argument(1).asInstanceOf[Assignment] + tmp0Assign.code shouldBe " = Digest" + + val digestFa = tmp0Assign.source.asInstanceOf[FieldAccess] + digestFa.argument(1).asInstanceOf[Identifier].name shouldBe RDefines.Self + digestFa.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "Digest" case xs => fail(s"Expected 2 arguments, got ${xs.code.mkString(", ")} instead") } case None => fail("Expected if-condition") @@ -601,7 +687,7 @@ class MethodTests extends RubyCode2CpgFixture { ) "be directly under :program" in { - inside(cpg.method.name(RDefines.Program).filename("t1.rb").assignment.l) { + inside(cpg.method.name(RDefines.Main).filename("t1.rb").assignment.l) { case moduleAssignment :: classAssignment :: methodAssignment :: Nil => moduleAssignment.code shouldBe "self.A = module A (...)" classAssignment.code shouldBe "self.B = class B (...)" @@ -611,7 +697,7 @@ class MethodTests extends RubyCode2CpgFixture { case (lhs: Call) :: (rhs: TypeRef) :: Nil => lhs.code shouldBe "self.A" lhs.name shouldBe Operators.fieldAccess - rhs.typeFullName shouldBe "t1.rb:::program.A" + rhs.typeFullName shouldBe s"t1.rb:$Main.A" case xs => fail(s"Expected lhs and rhs, instead got ${xs.code.mkString(",")}") } @@ -619,7 +705,7 @@ class MethodTests extends RubyCode2CpgFixture { case (lhs: Call) :: (rhs: TypeRef) :: Nil => lhs.code shouldBe "self.B" lhs.name shouldBe Operators.fieldAccess - rhs.typeFullName shouldBe "t1.rb:::program.B" + rhs.typeFullName shouldBe s"t1.rb:$Main.B" case xs => fail(s"Expected lhs and rhs, instead got ${xs.code.mkString(",")}") } @@ -627,8 +713,8 @@ class MethodTests extends RubyCode2CpgFixture { case (lhs: Call) :: (rhs: MethodRef) :: Nil => lhs.code shouldBe "self.c" lhs.name shouldBe Operators.fieldAccess - rhs.methodFullName shouldBe "t1.rb:::program:c" - rhs.typeFullName shouldBe "t1.rb:::program:c" + rhs.methodFullName shouldBe s"t1.rb:$Main.c" + rhs.typeFullName shouldBe s"t1.rb:$Main.c" case xs => fail(s"Expected lhs and rhs, instead got ${xs.code.mkString(",")}") } @@ -637,7 +723,7 @@ class MethodTests extends RubyCode2CpgFixture { } "not be present in other files" in { - inside(cpg.method.name(RDefines.Program).filename("t2.rb").assignment.l) { + inside(cpg.method.name(RDefines.Main).filename("t2.rb").assignment.l) { case classAssignment :: methodAssignment :: Nil => classAssignment.code shouldBe "self.D = class D (...)" methodAssignment.code shouldBe "self.e = def e (...)" @@ -646,7 +732,7 @@ class MethodTests extends RubyCode2CpgFixture { case (lhs: Call) :: (rhs: TypeRef) :: Nil => lhs.code shouldBe "self.D" lhs.name shouldBe Operators.fieldAccess - rhs.typeFullName shouldBe "t2.rb:::program.D" + rhs.typeFullName shouldBe s"t2.rb:$Main.D" case xs => fail(s"Expected lhs and rhs, instead got ${xs.code.mkString(",")}") } @@ -654,8 +740,8 @@ class MethodTests extends RubyCode2CpgFixture { case (lhs: Call) :: (rhs: MethodRef) :: Nil => lhs.code shouldBe "self.e" lhs.name shouldBe Operators.fieldAccess - rhs.methodFullName shouldBe "t2.rb:::program:e" - rhs.typeFullName shouldBe "t2.rb:::program:e" + rhs.methodFullName shouldBe s"t2.rb:$Main.e" + rhs.typeFullName shouldBe s"t2.rb:$Main.e" case xs => fail(s"Expected lhs and rhs, instead got ${xs.code.mkString(",")}") } @@ -664,15 +750,419 @@ class MethodTests extends RubyCode2CpgFixture { } "be placed in order of definition" in { - inside(cpg.method.name(RDefines.Program).filename("t1.rb").block.astChildren.l) { + inside(cpg.method.name(RDefines.Main).filename("t1.rb").block.astChildren.isCall.l) { case (a1: Call) :: (a2: Call) :: (a3: Call) :: (a4: Call) :: (a5: Call) :: Nil => a1.code shouldBe "self.A = module A (...)" - a2.code shouldBe "self::A::" + a2.code shouldBe "( = self::A)::()" a3.code shouldBe "self.B = class B (...)" - a4.code shouldBe "self::B::" + a4.code shouldBe "( = self::B)::()" a5.code shouldBe "self.c = def c (...)" case xs => fail(s"Expected assignments to appear before definitions, instead got [${xs.mkString("\n")}]") } } } + + "Splatting and normal argument" in { + val cpg = code(""" + |def foo(*x, y) + |end + |""".stripMargin) + + inside(cpg.method.name("foo").l) { + case fooMethod :: Nil => + inside(fooMethod.method.parameter.l) { + case selfArg :: splatArg :: normalArg :: Nil => + splatArg.code shouldBe "*x" + splatArg.index shouldBe 1 + + normalArg.code shouldBe "y" + normalArg.index shouldBe 2 + case xs => fail(s"Expected two parameters, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected one method, got [${xs.code.mkString(",")}]") + } + } + + "Splatting argument in call" in { + val cpg = code(""" + |def foo(a, b) + |end + | + |x = 1,2 + |foo(*x, y) + |""".stripMargin) + + inside(cpg.call.name("foo").l) { + case fooCall :: Nil => + inside(fooCall.argument.l) { + case selfArg :: xArg :: yArg :: Nil => + xArg.code shouldBe "*x" + yArg.code shouldBe "self.y" + case xs => fail(s"Expected two args, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected one call to foo, got [${xs.code.mkString(",")}]") + } + } + + "a nested method declaration inside of a do-block should connect the member node to the bound type decl" in { + val cpg = code(""" + |foo do + | def bar + | end + |end + |""".stripMargin) + + val parentType = cpg.member("bar").typeDecl.head + parentType.isLambda should not be empty + parentType.methodBinding.methodFullName.head should endWith("0") + } + + "a method that is redefined should have a counter suffixed to ensure uniqueness" in { + val cpg = code(""" + |def foo;end + |def bar;end + |def foo;end + |def foo;end + |""".stripMargin) + + cpg.method.name("(foo|bar).*").name.l shouldBe List("foo", "bar", "foo", "foo") + cpg.method.name("(foo|bar).*").fullName.l shouldBe List( + s"Test0.rb:$Main.foo", + s"Test0.rb:$Main.bar", + s"Test0.rb:$Main.foo0", + s"Test0.rb:$Main.foo1" + ) + } + + "MemberCall with a function name the same as a reserved keyword" in { + val cpg = code(""" + |batch.retry!() + |""".stripMargin) + + inside(cpg.call.name(".*retry!").l) { + case batchCall :: Nil => + batchCall.name shouldBe "retry!" + batchCall.code shouldBe "( = batch).retry!()" + + inside(batchCall.receiver.l) { + case (receiverCall: Call) :: Nil => + receiverCall.name shouldBe Operators.fieldAccess + receiverCall.code shouldBe "( = batch).retry!" + + val selfBatch = receiverCall.argument(1).asInstanceOf[Call] + selfBatch.code shouldBe " = batch" + + val retry = receiverCall.argument(2).asInstanceOf[FieldIdentifier] + retry.code shouldBe "retry!" + + case xs => fail(s"Expected one receiver for call, got [${xs.code.mkString(",")}]") + } + + case xs => fail(s"Expected one method for batch.retry, got [${xs.code.mkString(",")}]") + } + } + + "Call with :: syntax and reserved keyword" in { + val cpg = code(""" + |batch::retry!() + |""".stripMargin) + + inside(cpg.call.name(".*retry!").l) { + case batchCall :: Nil => + batchCall.name shouldBe "retry!" + batchCall.code shouldBe "( = batch)::retry!()" + + inside(batchCall.receiver.l) { + case (receiverCall: Call) :: Nil => + receiverCall.name shouldBe Operators.fieldAccess + receiverCall.code shouldBe "( = batch).retry!" + + val selfBatch = receiverCall.argument(1).asInstanceOf[Call] + selfBatch.code shouldBe " = batch" + + val retry = receiverCall.argument(2).asInstanceOf[FieldIdentifier] + retry.code shouldBe "retry!" + + case xs => fail(s"Expected one receiver for call, got [${xs.code.mkString(",")}]") + } + + case xs => fail(s"Expected one method for batch.retry, got [${xs.code.mkString(",")}]") + } + } + + "Call with reserved keyword as base and call name using . notation" in { + val cpg = code(""" + |retry.retry!() + |""".stripMargin) + + inside(cpg.call.name(".*retry!").l) { + case batchCall :: Nil => + batchCall.name shouldBe "retry!" + batchCall.code shouldBe "( = retry).retry!()" + + inside(batchCall.receiver.l) { + case (receiverCall: Call) :: Nil => + receiverCall.name shouldBe Operators.fieldAccess + receiverCall.code shouldBe "( = retry).retry!" + + val selfBatch = receiverCall.argument(1).asInstanceOf[Call] + selfBatch.code shouldBe " = retry" + + val retry = receiverCall.argument(2).asInstanceOf[FieldIdentifier] + retry.code shouldBe "retry!" + + case xs => fail(s"Expected one receiver for call, got [${xs.code.mkString(",")}]") + } + + case xs => fail(s"Expected one method for batch.retry, got [${xs.code.mkString(",")}]") + } + } + + "Call with reserved keyword as base and call name" in { + val cpg = code(""" + |retry::retry!() + |""".stripMargin) + + inside(cpg.call.name(".*retry!").l) { + case batchCall :: Nil => + batchCall.name shouldBe "retry!" + batchCall.code shouldBe "( = retry)::retry!()" + + inside(batchCall.receiver.l) { + case (receiverCall: Call) :: Nil => + receiverCall.name shouldBe Operators.fieldAccess + receiverCall.code shouldBe "( = retry).retry!" + + val selfBatch = receiverCall.argument(1).asInstanceOf[Call] + selfBatch.code shouldBe " = retry" + + val retry = receiverCall.argument(2).asInstanceOf[FieldIdentifier] + retry.code shouldBe "retry!" + + case xs => fail(s"Expected one receiver for call, got [${xs.code.mkString(",")}]") + } + + case xs => fail(s"Expected one method for batch.retry, got [${xs.code.mkString(",")}]") + } + } + + "%x should be represented as a call to EXEC" in { + val cpg = code(""" + |%x(ls -l) + |""".stripMargin) + + inside(cpg.call.name(RubyOperators.backticks).l) { + case execCall :: Nil => + execCall.name shouldBe RubyOperators.backticks + inside(execCall.argument.l) { + case selfArg :: lsArg :: Nil => + selfArg.code shouldBe "self" + lsArg.code shouldBe "ls -l" + case xs => fail(s"expected 2 arguments, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected one call to exec, got [${xs.code.mkString(",")}]") + } + } + + "MemberAccessCommand with two parameters" in { + val cpg = code("foo&.bar 1,2") + + inside(cpg.call.name("bar").l) { + case barCall :: Nil => + inside(barCall.argument.l) { + case _ :: (arg1: Literal) :: (arg2: Literal) :: Nil => + arg1.code shouldBe "1" + arg1.typeFullName shouldBe RDefines.prefixAsCoreType(RDefines.Integer) + + arg2.code shouldBe "2" + arg2.typeFullName shouldBe RDefines.prefixAsCoreType(RDefines.Integer) + case xs => fail(s"Expected three args, got [${xs.code.mkString(",")}]") + } + + case xs => fail(s"Expected one call, got [${xs.code.mkString(",")}]") + } + } + + "Method def in class defined in a namespace" in { + val cpg = code(""" + |class Api::V1::MobileController + | def show + | end + |end + |""".stripMargin) + + inside(cpg.method.name("show").l) { + case showMethod :: Nil => + showMethod.astParentFullName shouldBe "Test0.rb:
.Api.V1.MobileController" + showMethod.astParentType shouldBe NodeTypes.TYPE_DECL + case xs => fail(s"Expected one methood, got ${xs.name.mkString(",")}") + } + } + + "Method def with mandatory arg after splat arg" in { + val cpg = code(""" + |def foo(a=1, *b, c) + |end + |""".stripMargin) + + inside(cpg.method.name("foo").parameter.l) { + case _ :: aParam :: bParam :: cParam :: Nil => + aParam.code shouldBe "a=1" + bParam.code shouldBe "*b" + cParam.code shouldBe "c" + case xs => fail(s"Expected 4 params, got ${xs.code.mkString(",")}") + } + } + + "Unnamed proc parameters" should { + val cpg = code(""" + |def outer_method(&) + | puts "In outer_method" + | inner_method(&) + |end + | + |def inner_method(&) + | puts "In inner_method" + | yield if block_given? + |end + | + |outer_method do + | puts "Hello from the block!" + |end + |""".stripMargin) + + "generate and reference proc param" in { + inside(cpg.method.name("outer_method").l) { + case outerMethod :: Nil => + val List(_, procParam) = outerMethod.parameter.l + procParam.name shouldBe "" + + inside(outerMethod.call.name("inner_method").argument.l) { + case _ :: procParamArg :: Nil => + procParamArg.code shouldBe "" + case xs => fail(s"Expected two arguments, got [${xs.code.mkString(",")}]") + } + + case xs => fail(s"Expected one method def, got [${xs.name.mkString(",")}]") + } + } + + "call correct proc param in `yield`" in { + inside(cpg.method.name("inner_method").l) { + case innerMethod :: Nil => + val List(_, procParam) = innerMethod.parameter.l + procParam.name shouldBe "" + + innerMethod.call.nameExact("call").argument.isIdentifier.name.l shouldBe List("") + + case xs => fail(s"Expected one method def, got [${xs.name.mkString(",")}]") + } + } + } + + "lambdas as arguments to a long chained call" should { + val cpg = code(""" + |def foo(xs, total_ys, hex_values) + | xs.map.with_index { |f, i| [f / total_ys, hex_values[i]] } # 1 + | .sort_by { |r| -r[0] } # 2 + | .reject { |r| r[1].size == 8 && r[1].end_with?('00') } # 3 + | .map { |r| Foo::Bar::Baz.new(*r[1][0..5].scan(/../).map { |c| c.to_i(16) }) } # 4 & 5 + | .slice(0, quantity) + | end + |""".stripMargin) + + "not write lambda nodes that are already assigned to some temp variable" in { + cpg.typeRef.typeFullName(".*Proc").size shouldBe 5 + cpg.typeRef.whereNot(_.astParent).size shouldBe 0 + } + + "resolve cached lambdas correctly" in { + def getLineNumberOfLambdaForCall(callName: String) = + cpg.call.nameExact(callName).argument.isTypeRef.typ.referencedTypeDecl.lineNumber.head + + getLineNumberOfLambdaForCall("with_index") shouldBe 3 + getLineNumberOfLambdaForCall("sort_by") shouldBe 4 + getLineNumberOfLambdaForCall("reject") shouldBe 5 + getLineNumberOfLambdaForCall("map") shouldBe 6 + } + } + + "Forwarded args from method to call" should { + val cpg = code(""" + |def foo(...) + | bar('foo', ...) + |end + | + |""".stripMargin) + + "create a '...' parameter node" in { + inside(cpg.method.nameExact("foo").parameter.l) { case _ :: forwardArgs :: Nil => + forwardArgs.name shouldBe "..." + forwardArgs.code shouldBe "(...)" + } + } + + "create a '...' identifier node as a call argument" in { + inside(cpg.call("bar").argument.isIdentifier.l) { case _ :: forwardedArgs :: Nil => + forwardedArgs.name shouldBe "..." + forwardedArgs.code shouldBe "..." + forwardedArgs.argumentIndex shouldBe 2 + } + } + } + + "Implicit return of range expression" in { + val cpg = code(""" + |def size_range + | 1..MAX_FILE_SIZE + |end""".stripMargin) + + inside(cpg.method.name("size_range").methodReturn.toReturn.l) { + case rangeReturn :: Nil => + rangeReturn.code shouldBe "1..MAX_FILE_SIZE" + + val List(rangeOp) = rangeReturn.astChildren.isCall.l + rangeOp.methodFullName shouldBe Operators.range + + val List(lhs: Literal, rhs: Call) = rangeOp.argument.l: @unchecked + lhs.code shouldBe "1" + + rhs.code shouldBe "self.MAX_FILE_SIZE" + case xs => fail(s"Expected one return, got [${xs.code.mkString(",")}]") + } + } + + "Method call with same name as reserved keyword" in { + val cpg = code(""" + | def public + | list.sort_by(&:position).filter_map { |category| category.slug if category.visible_to_public? } + | end + | + | def notifiable + | public + | end + | + | def not_notifiable + | public + | puts 1 + | puts 2 + | end + |""".stripMargin) + + inside(cpg.method.name("notifiable").body.astChildren.isReturn.astChildren.isCall.name("public").l) { + case publicCall :: Nil => + publicCall.code shouldBe "public" + + val List(selfArg) = publicCall.argument.l + case xs => fail(s"Expected one call, got ${xs.code.mkString(",")}") + } + + inside(cpg.method.name("not_notifiable").body.astChildren.isCall.name("public").l) { + case publicCall :: Nil => + publicCall.code shouldBe "public" + + val List(selfArg) = publicCall.argument.l + case xs => fail(s"Expected one call, got ${xs.code.mkString(",")}") + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ModuleTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ModuleTests.scala index 818e191e8b8d..cb7c5ce9281b 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ModuleTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ModuleTests.scala @@ -1,7 +1,10 @@ package io.joern.rubysrc2cpg.querying -import io.joern.rubysrc2cpg.passes.Defines +import io.joern.rubysrc2cpg.passes.Defines.Main +import io.joern.rubysrc2cpg.passes.{Defines, GlobalTypes} import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.nodes.{File, Literal, NamespaceBlock} +import io.shiftleft.codepropertygraph.generated.{ModifierTypes, NodeTypes} import io.shiftleft.semanticcpg.language.* class ModuleTests extends RubyCode2CpgFixture { @@ -14,7 +17,7 @@ class ModuleTests extends RubyCode2CpgFixture { val List(m) = cpg.typeDecl.name("M").l - m.fullName shouldBe "Test0.rb:::program.M" + m.fullName shouldBe s"Test0.rb:$Main.M" m.lineNumber shouldBe Some(2) m.baseType.l shouldBe List() m.member.name.l shouldBe List(Defines.TypeDeclBody) @@ -31,11 +34,138 @@ class ModuleTests extends RubyCode2CpgFixture { val List(m) = cpg.typeDecl.name("M1").l - m.fullName shouldBe "Test0.rb:::program.M1" + m.fullName shouldBe s"Test0.rb:$Main.M1" m.lineNumber shouldBe Some(2) m.baseType.l shouldBe List() m.member.name.l shouldBe List(Defines.TypeDeclBody) m.method.name.l shouldBe List(Defines.TypeDeclBody) } + "Module defined in Namespace" in { + val cpg = code(""" + |module Api::V1::MobileController + |end + |""".stripMargin) + + inside(cpg.namespaceBlock.fullNameExact("Api.V1").typeDecl.l) { + case mobileNamespace :: mobileClassNamespace :: Nil => + mobileNamespace.name shouldBe "MobileController" + mobileNamespace.fullName shouldBe "Test0.rb:
.Api.V1.MobileController" + + mobileClassNamespace.name shouldBe "MobileController" + mobileClassNamespace.fullName shouldBe "Test0.rb:
.Api.V1.MobileController" + case xs => fail(s"Expected two namespace blocks, got ${xs.code.mkString(",")}") + } + + inside(cpg.typeDecl.name("MobileController").l) { + case mobileTypeDecl :: Nil => + mobileTypeDecl.name shouldBe "MobileController" + mobileTypeDecl.fullName shouldBe "Test0.rb:
.Api.V1.MobileController" + mobileTypeDecl.astParentFullName shouldBe "Api.V1" + mobileTypeDecl.astParentType shouldBe NodeTypes.NAMESPACE_BLOCK + + mobileTypeDecl.astParent.isNamespaceBlock shouldBe true + + val namespaceDecl = mobileTypeDecl.astParent.asInstanceOf[NamespaceBlock] + namespaceDecl.name shouldBe "Api.V1" + namespaceDecl.filename shouldBe "Test0.rb" + + namespaceDecl.astParent.isFile shouldBe true + val parentFileDecl = namespaceDecl.astParent.asInstanceOf[File] + parentFileDecl.name shouldBe "Test0.rb" + + case xs => fail(s"Expected one class decl, got [${xs.code.mkString(",")}]") + } + } + + "Class Method Modifiers" should { + val cpg = code(""" + |# Taken from Mastodon Repo + |module LanguagesHelper + | ISO_639_1 = {} + | ISO_639_3 = {} + | SUPPORTED_LOCALES = {} + | REGIONAL_LOCALE_NAMES = {} + | + | private_class_method def self.locale_name_for_sorting(locale) + | if (supported_locale = SUPPORTED_LOCALES[locale.to_sym]) + | ASCIIFolding.new.fold(supported_locale[1]).downcase + | elsif (regional_locale = REGIONAL_LOCALE_NAMES[locale.to_sym]) + | ASCIIFolding.new.fold(regional_locale).downcase + | else + | locale + | end + | end + | + | def publicMethodAfterwards + | end + |end + |""".stripMargin) + "Generate private modifier on method" in { + inside(cpg.method.name("locale_name_for_sorting")._modifierViaAstOut.l) { + case virtualModifier :: privateModifier :: Nil => + virtualModifier.modifierType shouldBe ModifierTypes.VIRTUAL + privateModifier.modifierType shouldBe ModifierTypes.PRIVATE + case xs => fail(s"Expected two modifiers, got [${xs.modifierType.mkString(",")}]") + } + } + + "Revert to original access modifier after previous method def" in { + inside(cpg.method.name("publicMethodAfterwards")._modifierViaAstOut.l) { + case virtualModifier :: publicModifier :: Nil => + virtualModifier.modifierType shouldBe ModifierTypes.VIRTUAL + publicModifier.modifierType shouldBe ModifierTypes.PUBLIC + case xs => fail(s"Expected got [${xs.modifierType.mkString(",")}]") + } + } + } + + "Protected call with block" should { + val cpg = code(""" + |module QA + | trait :protected do + | protected { true } + | end + |end + |""".stripMargin) + + "Have the correct proc arg in call" in { + inside(cpg.call.name("protected").argument.l) { + case _ :: proc :: Nil => + proc.code shouldBe "1&Proc" + case xs => fail(s"Expected one call for protected, got [${xs.code.mkString(",")}]") + } + } + + "Generate a lambda with true body" in { + inside(cpg.method.isLambda.l) { + case protectedLambda :: _ :: Nil => + val List(lambdaReturn) = protectedLambda.body.astChildren.isReturn.l + lambdaReturn.code shouldBe "true" + case xs => fail(s"Expected two lambdas, got [${xs.code.mkString(",")}]") + } + } + } + "Argument `(...)` in call should not be lifted" in { + val cpg = code(""" + |module ArticlesHelper + | def foo + | end + | + | def active_threads(...) + | Articles::ActiveThreadsQuery.call(...) + | end + |end + | + |""".stripMargin) + + inside(cpg.typeDecl.name("ArticlesHelper").method.l) { + case bodyMethod :: _ :: _ :: Nil => + inside(bodyMethod.block.astChildren.l) { + case Nil => // bodyMethod should be empty + case xs => fail(s"Expected empty body, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected three methods got [${xs.name.mkString(",")}]") + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ProcParameterAndYieldTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ProcParameterAndYieldTests.scala index 67abaca49a42..af3929e81308 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ProcParameterAndYieldTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ProcParameterAndYieldTests.scala @@ -1,84 +1,149 @@ package io.joern.rubysrc2cpg.querying +import io.joern.rubysrc2cpg.passes.Defines.RubyOperators import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.Operators -import org.scalatest.Inspectors -import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* +import org.scalatest.Inspectors class ProcParameterAndYieldTests extends RubyCode2CpgFixture with Inspectors { - "Methods" should { - "with a yield expression" should { - "with a proc parameter" should { - val cpg1 = code("def foo(&b) yield end") - val cpg2 = code("def self.foo(&b) yield end") - val cpgs = List(cpg1, cpg2) - - "have a single block argument" in { - forAll(cpgs)(_.method("foo").parameter.code("&.*").name.l shouldBe List("b")) - } - "represent the yield as a conditional with a call and return node as children" in { - forAll(cpgs) { cpg => - inside(cpg.method("foo").call.nameExact(Operators.conditional).code("yield").astChildren.l) { - case List(cond: Expression, call: Call, ret: Return) => { - cond.code shouldBe "" - call.name shouldBe "b" - call.code shouldBe "b" - ret.code shouldBe "yield" - } - } - } - } - } - - "without a proc parameter" should { - val cpg = code(""" - |def foo() yield end - |def self.bar() yield end - |""".stripMargin) - - "have a call to a block parameter" in { - cpg.method("foo").call.code("yield").astChildren.isCall.code("").name.l shouldBe List( - "" - ) - cpg.method("bar").call.code("yield").astChildren.isCall.code("").name.l shouldBe List( - "" - ) - } + "a method with an explicit proc parameter should create an invocation of it's `call` member" in { + val cpg = code("def foo(&b) yield end") - "add a block argument" in { - val List(param1) = cpg.method("foo").parameter.code("&.*").l - param1.name shouldBe "" - param1.index shouldBe 1 + val foo = cpg.method("foo").head - val List(param2) = cpg.method("bar").parameter.code("&.*").l - param2.name shouldBe "" - param2.index shouldBe 1 - } - } - - "with yield arguments" should { - val cpg = code("def foo(x) yield(x) end") - "replace the yield with a call to the block parameter with arguments" in { - val List(call) = cpg.call.codeExact("yield(x)").astChildren.isCall.codeExact("").l - call.name shouldBe "" - call.argument.code.l shouldBe List("self", "x") - } + val bParam = foo.parameter.last + bParam.name shouldBe "b" + bParam.code shouldBe "&b" + bParam.index shouldBe 1 + + inside(foo.call.nameExact("call").argument.l) { case selfBase :: Nil => + selfBase.code shouldBe "b" + } + } + + "a singleton method with an explicit proc parameter should create an invocation of it's `call` member" in { + val cpg = code("def self.foo(&b) yield end") + + val foo = cpg.method("foo").head + + val bParam = foo.parameter.last + bParam.name shouldBe "b" + bParam.code shouldBe "&b" + bParam.index shouldBe 1 - } + inside(foo.call.nameExact("call").argument.l) { case selfBase :: Nil => + selfBase.code shouldBe "b" } + } + + "a method with an implicit proc parameter should create an invocation using a unique parameter name" in { + val cpg = code(""" + |def foo() yield end + |def self.bar() yield end + |""".stripMargin) + + val foo = cpg.method("foo").head + val bar = cpg.method("bar").head + + val fooParam = foo.parameter.last + fooParam.name shouldBe "" + fooParam.code shouldBe "&" + fooParam.index shouldBe 1 + + val barParam = bar.parameter.last + barParam.name shouldBe "" + barParam.code shouldBe "&" + barParam.index shouldBe 1 - "that don't have a yield nor a proc parameter" should { - val cpg1 = code("def foo() end") - val cpg2 = code("def self.foo() end") - val cpgs = List(cpg1, cpg2) + foo.call.nameExact("call").argument.isIdentifier.name.l shouldBe List("") + bar.call.nameExact("call").argument.isIdentifier.name.l shouldBe List("") + } + + "a method with an implicit proc parameter should create an invocation of it's `call` member with given arguments" in { + val cpg = code("def foo(x) yield(x) end") + + val foo = cpg.method("foo").head + + val List(xParam, procParam) = foo.parameter.l.takeRight(2) + + xParam.name shouldBe "x" + xParam.index shouldBe 1 - "not add a block argument" in { - forAll(cpgs)(_.method("foo").parameter.code("&.*").name.l should be(empty)) - } + procParam.name shouldBe "" + procParam.code shouldBe "&" + procParam.index shouldBe 2 + + inside(foo.call.nameExact("call").argument.l) { case selfBase :: x :: Nil => + selfBase.code shouldBe "" + selfBase.argumentIndex shouldBe 0 + x.code shouldBe "x" + x.argumentIndex shouldBe 1 } + } + + "a method without a yield nor proc parameter should not have either modelled" in { + val cpg1 = code("def foo() end") + val cpg2 = code("def self.foo() end") + val cpgs = List(cpg1, cpg2) + + forAll(cpgs)(cpg => { + cpg.method("foo").parameter.code("&.*").name.l should be(empty) + cpg.method("foo").call.nameExact("call").name.l should be(empty) + }) + } + "A Yield statement with multiple arguments" in { + val cpg = code(""" + |def foo + | yield 1, :z => 2 + |end + |""".stripMargin) + + inside(cpg.method.name("foo").call.nameExact("call").l) { + case yieldCall :: Nil => + inside(yieldCall.argument.l) { + case (base: Identifier) :: (oneLiteral: Literal) :: (twoLiteral: Literal) :: Nil => + base.name shouldBe "" + base.code shouldBe "" + + oneLiteral.code shouldBe "1" + oneLiteral.argumentIndex shouldBe 1 + twoLiteral.code shouldBe "2" + twoLiteral.argumentName shouldBe Some("z") + case xs => fail(s"Expected two arguments for yieldCall, got ${xs.code.mkString(",")}") + } + case xs => fail(s"Expected one call for yield, got ${xs.code.mkString(",")}") + } } + "Yield in initialize should create implicit proc parameter" in { + val cpg = code(""" + |class Payload + |def initialize + | yield(self) + |end + |end + |""".stripMargin) + + val initMethod = cpg.method.name("initialize").head + + inside(initMethod.parameter.l) { + case _ :: procParam :: Nil => + // This seems a bit strange, but the `` method is being processed first which generates a procParam + // for the `MethodScope` which is why the procParam for this ConstructorScope is [1] instead of [0] + procParam.name shouldBe "" + procParam.code shouldBe "&" + procParam.index shouldBe 1 + case xs => fail(s"Expected two arguments, got [${xs.code.mkString(",")}]") + } + + inside(initMethod.call.nameExact("call").argument.l) { case selfBase :: selfParam :: Nil => + selfBase.code shouldBe "" + selfBase.argumentIndex shouldBe 0 + selfParam.code shouldBe "self" + selfParam.argumentIndex shouldBe 1 + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/RangeTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/RangeTests.scala index 72db05379666..28633186b7cf 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/RangeTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/RangeTests.scala @@ -23,4 +23,31 @@ class RangeTests extends RubyCode2CpgFixture { upperBound.code shouldBe "1" } + "`0..` is represented by a `range` operator call with infinity upperbound" in { + val cpg = code("0..") + val List(range) = cpg.call(Operators.range).l + + range.methodFullName shouldBe Operators.range + range.code shouldBe "0.." + range.lineNumber shouldBe Some(1) + + val List(lowerBound, upperBound) = range.argument.l + + lowerBound.code shouldBe "0" + upperBound.code shouldBe "Float::INFINITY" + } + + "`..0` is represented by a `range` operator call with infinity lowerbound" in { + val cpg = code("..0") + val List(range) = cpg.call(Operators.range).l + + range.methodFullName shouldBe Operators.range + range.code shouldBe "..0" + range.lineNumber shouldBe Some(1) + + val List(lowerBound, upperBound) = range.argument.l + + lowerBound.code shouldBe "-Float::INFINITY" + upperBound.code shouldBe "0" + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/RegexTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/RegexTests.scala index 674a39f6f40d..63bb16c2f9c0 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/RegexTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/RegexTests.scala @@ -1,51 +1,130 @@ package io.joern.rubysrc2cpg.querying -import io.joern.rubysrc2cpg.passes.Defines.RubyOperators -import io.joern.rubysrc2cpg.passes.GlobalTypes.kernelPrefix import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.nodes.Literal +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal} +import io.shiftleft.codepropertygraph.generated.{Cpg, Operators} import io.shiftleft.semanticcpg.language.* -class RegexTests extends RubyCode2CpgFixture(withPostProcessing = true) { - "`'x' =~ y` is a member call `'x'.=~ /y/" in { - val cpg = code("""|'x' =~ /y/ - |0 - |""".stripMargin) - cpg.call(RubyOperators.regexpMatch).methodFullName.l shouldBe List( - s"$kernelPrefix.String:${RubyOperators.regexpMatch}" - ) - } - "`/x/ =~ 'y'` is a member call `/x/.=~ 'y'" in { - val cpg = code("""|/x/ =~ 'y' - |0 - |""".stripMargin) - cpg.call(RubyOperators.regexpMatch).methodFullName.l shouldBe List( - s"$kernelPrefix.Regexp:${RubyOperators.regexpMatch}" - ) - } +class RegexTests extends RubyCode2CpgFixture(withPostProcessing = false) { + + "Global regex related variables" should { + + /** Checks for the presence of the lowered regex match which assigns the match results to the respective global + * variables. + * + * TODO: Check for matching of match group ($1, $2, etc.) variables. + */ + def assertLoweredStructure(cpg: Cpg, tmpNo: String = "0", expectedSubject: String = "\"hello\""): Unit = { + // We lower =~ to the `match` equivalent + val tmpInit = cpg.assignment.code(s" =.*").head + + val tmpTarget = tmpInit.target.asInstanceOf[Identifier] + tmpTarget.name shouldBe s"" + val tmpSource = tmpInit.source.asInstanceOf[Call] + tmpSource.code shouldBe s"/h(el)lo/.match($expectedSubject)" + tmpSource.name shouldBe "match" + tmpSource.methodFullName shouldBe "__core.Regexp.match" + + // Now test for the lowered global variable assignments + val ifStmt = cpg.controlStructure.last + inside(ifStmt.whenTrue.assignment.l) { case tildeAsgn :: amperAsgn :: match1Asgn :: Nil => + tildeAsgn.code shouldBe s"$$~ = " + val taSource = tildeAsgn.source.asInstanceOf[Identifier] + taSource.name shouldBe s"" + val taTarget = tildeAsgn.target.asInstanceOf[Call] + taTarget.methodFullName shouldBe Operators.fieldAccess + taTarget.code shouldBe "self.$~" + + amperAsgn.code shouldBe s"$$& = [0]" + val aaSource = amperAsgn.source.asInstanceOf[Call] + aaSource.methodFullName shouldBe Operators.indexAccess + aaSource.code shouldBe s"[0]" + aaSource.argument(1).asInstanceOf[Identifier].name shouldBe s"" + aaSource.argument(2).asInstanceOf[Literal].code shouldBe "0" + + val aaTarget = amperAsgn.target.asInstanceOf[Call] + aaTarget.methodFullName shouldBe Operators.fieldAccess + aaTarget.code shouldBe "self.$&" + + match1Asgn.code shouldBe s"$$1 = [1]" + val match1AsgnSource = match1Asgn.source.asInstanceOf[Call] + match1AsgnSource.methodFullName shouldBe Operators.indexAccess + match1AsgnSource.code shouldBe s"[1]" + + val match1AsgnTarget = match1Asgn.target.asInstanceOf[Call] + match1AsgnTarget.methodFullName shouldBe Operators.indexAccess + match1AsgnTarget.code shouldBe "$[1]" + } + inside(ifStmt.whenFalse.assignment.l) { case tildeAsgn :: amperAsgn :: Nil => + tildeAsgn.code shouldBe "$~ = nil" + val taSource = tildeAsgn.source.asInstanceOf[Literal] + taSource.code shouldBe "nil" + val taTarget = tildeAsgn.target.asInstanceOf[Call] + taTarget.methodFullName shouldBe Operators.fieldAccess + taTarget.code shouldBe "self.$~" + + amperAsgn.code shouldBe "$& = nil" + val aaSource = amperAsgn.source.asInstanceOf[Literal] + aaSource.code shouldBe "nil" + + val aaTarget = amperAsgn.target.asInstanceOf[Call] + aaTarget.methodFullName shouldBe Operators.fieldAccess + aaTarget.code shouldBe "self.$&" + } + } + + "be assigned to the match by the `~=` operator" in { - "Regex expression in if statements" in { - val cpg = code(""" - | - |if /mswin|mingw|cygwin/ =~ "mswin" - |end - |""".stripMargin) + val cpg = code(""" + |"hello" =~ /h(el)lo/ + |""".stripMargin) - inside(cpg.controlStructure.isIf.l) { - case regexIf :: Nil => - regexIf.condition.isCall.methodFullName.l shouldBe List(s"$kernelPrefix.Regexp:${RubyOperators.regexpMatch}") + assertLoweredStructure(cpg) + } + + "be assigned to the match in a case equality" in { + val cpg = code(""" + |case "hello" + |when /h(el)lo/ + | puts $1 + |end + |""".stripMargin) + + assertLoweredStructure(cpg, "1", "") + } + + "be assigned to the match in a match call (regex lhs)" in { + val cpg = code(""" + |/h(el)lo/.match("hello") + |""".stripMargin) + + assertLoweredStructure(cpg) + } - inside(regexIf.condition.isCall.argument.l) { - case (lhs: Literal) :: (rhs: Literal) :: Nil => - lhs.code shouldBe "/mswin|mingw|cygwin/" - lhs.typeFullName shouldBe s"$kernelPrefix.Regexp" + "be assigned to the match in a match call (regex rhs)" in { + val cpg = code(""" + |"hello".match(/h(el)lo/) + |""".stripMargin) - rhs.code shouldBe "\"mswin\"" - rhs.typeFullName shouldBe s"$kernelPrefix.String" - case xs => fail(s"Expected two arguments, got [${xs.code.mkString(",")}]") - } + assertLoweredStructure(cpg) + } + + "be assigned to the match using string indexing" in { + val cpg = code(""" + |"hello"[/h(el)lo/] + |""".stripMargin) - case xs => fail(s"One if statement expected, got [${xs.code.mkString(",")}]") + assertLoweredStructure(cpg) } + + "be assigned to the match using `sub` (or `gsub`) calls" in { + val cpg = code(""" + |"hello".sub(/h(el)lo/) + |""".stripMargin) + + assertLoweredStructure(cpg) + } + } + } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SetterTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SetterTests.scala deleted file mode 100644 index 7381411226bb..000000000000 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SetterTests.scala +++ /dev/null @@ -1,29 +0,0 @@ -package io.joern.rubysrc2cpg.querying - -import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.semanticcpg.language.* - -class SetterTests extends RubyCode2CpgFixture { - - "`x.y=1` is represented by a `x.y=` CALL with argument `1`" in { - val cpg = code("""x = Foo.new - |x.y = 1 - |""".stripMargin) - - val List(setter) = cpg.call("y=").l - val List(fieldAccess) = cpg.fieldAccess.l - - setter.code shouldBe "x.y = 1" - setter.lineNumber shouldBe Some(2) - setter.receiver.l shouldBe List(fieldAccess) - - fieldAccess.code shouldBe "x.y=" - fieldAccess.lineNumber shouldBe Some(2) - fieldAccess.fieldIdentifier.code.l shouldBe List("y=") - - val List(_, one) = setter.argument.l - one.code shouldBe "1" - one.lineNumber shouldBe Some(2) - } - -} diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala index ff2dd238f5b2..4283cb51c259 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala @@ -1,8 +1,9 @@ package io.joern.rubysrc2cpg.querying +import io.joern.rubysrc2cpg.passes.Defines as RubyDefines import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Identifier, Literal} -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal} +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators} import io.shiftleft.semanticcpg.language.* class SingleAssignmentTests extends RubyCode2CpgFixture { @@ -42,34 +43,54 @@ class SingleAssignmentTests extends RubyCode2CpgFixture { rhs.code shouldBe "1" } - "`||=` is represented by an `assignmentOr` operator call" in { + "`||=` is represented by a lowered if call to .nil?" in { val cpg = code(""" - |x ||= false + |def foo(x) + | x ||= false + |end |""".stripMargin) - val List(assignment) = cpg.call(Operators.assignmentOr).l - assignment.code shouldBe "x ||= false" - assignment.lineNumber shouldBe Some(2) - assignment.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + inside(cpg.method.name("foo").controlStructure.l) { + case ifStruct :: Nil => + ifStruct.controlStructureType shouldBe ControlStructureTypes.IF + ifStruct.condition.code.l shouldBe List("!x") + + inside(ifStruct.whenTrue.ast.isCall.name(Operators.assignment).l) { + case assignmentCall :: Nil => + assignmentCall.code shouldBe "x = false" + val List(lhs, rhs) = assignmentCall.argument.l + lhs.code shouldBe "x" + rhs.code shouldBe "false" + case xs => fail(s"Expected assignment call in true branch, got ${xs.code.mkString}") + } - val List(lhs, rhs) = assignment.argument.l - lhs.code shouldBe "x" - rhs.code shouldBe "false" + case xs => fail(s"Expected one control structure, got ${xs.code.mkString(",")}") + } } - "`&&=` is represented by an `assignmentAnd` operator call" in { + "`&&=` is represented by lowered if call to .nil?" in { val cpg = code(""" - |x &&= true + |def foo(x) + | x &&= true + |end |""".stripMargin) - val List(assignment) = cpg.call(Operators.assignmentAnd).l - assignment.code shouldBe "x &&= true" - assignment.lineNumber shouldBe Some(2) - assignment.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + inside(cpg.method.name("foo").controlStructure.l) { + case ifStruct :: Nil => + ifStruct.controlStructureType shouldBe ControlStructureTypes.IF + ifStruct.condition.code.l shouldBe List("x") + + inside(ifStruct.whenTrue.ast.isCall.name(Operators.assignment).l) { + case assignmentCall :: Nil => + assignmentCall.code shouldBe "x = true" + val List(lhs, rhs) = assignmentCall.argument.l + lhs.code shouldBe "x" + rhs.code shouldBe "true" + case xs => fail(s"Expected assignment call in true branch, got ${xs.code.mkString}") + } - val List(lhs, rhs) = assignment.argument.l - lhs.code shouldBe "x" - rhs.code shouldBe "true" + case xs => fail(s"Expected one control structure, got ${xs.code.mkString(",")}") + } } "`/=` is represented by an `assignmentDivision` operator call" in { @@ -167,7 +188,7 @@ class SingleAssignmentTests extends RubyCode2CpgFixture { assign1.argument(2).code shouldBe "1" assign1.argument(2).lineNumber shouldBe Some(4) - assign2.lineNumber shouldBe Some(5) + assign2.lineNumber shouldBe Some(6) assign2.argument(1).code shouldBe "x" assign2.argument(2).code shouldBe "2" assign2.argument(2).lineNumber shouldBe Some(6) @@ -177,7 +198,7 @@ class SingleAssignmentTests extends RubyCode2CpgFixture { assign3.argument(2).code shouldBe "3" assign3.argument(2).lineNumber shouldBe Some(10) - assign4.lineNumber shouldBe Some(11) + assign4.lineNumber shouldBe Some(12) assign4.argument(1).code shouldBe "x" assign4.argument(2).code shouldBe "4" assign4.argument(2).lineNumber shouldBe Some(12) @@ -233,4 +254,306 @@ class SingleAssignmentTests extends RubyCode2CpgFixture { } } + "Bracket Assignments" in { + val cpg = code(""" + | def get_pto_schedule + | begin + | schedules = current_user.paid_time_off.schedule + | jfs = [] + | schedules.each do |s| + | hash = Hash.new + | hash[:id] = s[:id] + | hash[:title] = s[:event_name] + | hash[:start] = s[:date_begin] + | hash[:end] = s[:date_end] + | jfs << hash + | end + | rescue + | end + | respond_to do |format| + | format.json { render json: jfs.to_json } + | end + | end + |""".stripMargin) + + inside(cpg.method.isLambda.l) { + case scheduleLambda :: _ :: _ :: Nil => + inside(scheduleLambda.call.name(Operators.assignment).l) { + case _ :: id :: title :: start :: end :: _ :: Nil => + id.code shouldBe "hash[:id] = s[:id]" + + inside(id.argument.l) { + case (lhs: Call) :: (rhs: Call) :: Nil => + lhs.methodFullName shouldBe Operators.indexAccess + lhs.code shouldBe "hash[:id]" + + rhs.methodFullName shouldBe Operators.indexAccess + rhs.code shouldBe "s[:id]" + + inside(lhs.argument.l) { + case base :: (index: Literal) :: Nil => + index.typeFullName shouldBe RubyDefines.prefixAsCoreType(RubyDefines.Symbol) + case xs => fail(s"Expected base and index, got [${xs.code.mkString(",")}]") + } + + inside(rhs.argument.l) { + case base :: (index: Literal) :: Nil => + index.typeFullName shouldBe RubyDefines.prefixAsCoreType(RubyDefines.Symbol) + case xs => fail(s"Expected base and index, got [${xs.code.mkString(",")}]") + } + + case xs => fail(s"Expected lhs and rhs, got ${xs.code.mkString(";")}]") + } + case xs => fail(s"Expected six assignments, got [${xs.code.mkString(";")}]") + } + case xs => fail(s"Expected three lambdas, got ${xs.size} lambdas instead") + } + } + + "Bracketed ||= is represented by a lowered if call to .nil?" in { + val cpg = code(""" + |def foo + | hash[:id] ||= s[:id] + |end + |""".stripMargin) + inside(cpg.method.name("foo").controlStructure.l) { + case ifStruct :: Nil => + ifStruct.controlStructureType shouldBe ControlStructureTypes.IF + ifStruct.condition.code.l shouldBe List("!hash[:id]") + + inside(ifStruct.whenTrue.ast.isCall.name(Operators.assignment).l) { + case assignmentCall :: Nil => + assignmentCall.code shouldBe "hash[:id] = s[:id]" + val List(lhs, rhs) = assignmentCall.argument.l + lhs.code shouldBe "hash[:id]" + rhs.code shouldBe "s[:id]" + case xs => fail(s"Expected assignment call in true branch, got ${xs.code.mkString}") + } + + case xs => fail(s"Expected one control structure, got ${xs.code.mkString(",")}") + } + } + + "Bracketed +=" in { + val cpg = code(""" + |hash[:id] += s[:id] + |""".stripMargin) + + inside(cpg.call.name(Operators.assignmentPlus).l) { + case assignmentCall :: Nil => + assignmentCall.code shouldBe "hash[:id] += s[:id]" + assignmentCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + + inside(assignmentCall.argument.l) { + case lhs :: rhs :: Nil => + lhs.code shouldBe "hash[:id]" + rhs.code shouldBe "s[:id]" + case xs => fail(s"Expected lhs and rhs arguments, got ${xs.code.mkString(",")}") + } + case xs => fail(s"Expected on assignmentOr call, got ${xs.code.mkString(",")}") + } + } + + "Bracketed &&= is represented by a lowere if call to .nil?" in { + val cpg = code(""" + |def foo + | hash[:id] &&= s[:id] + |end + |""".stripMargin) + inside(cpg.method.name("foo").controlStructure.l) { + case ifStruct :: Nil => + ifStruct.controlStructureType shouldBe ControlStructureTypes.IF + ifStruct.condition.code.l shouldBe List("hash[:id]") + + inside(ifStruct.whenTrue.ast.isCall.name(Operators.assignment).l) { + case assignmentCall :: Nil => + assignmentCall.code shouldBe "hash[:id] = s[:id]" + val List(lhs, rhs) = assignmentCall.argument.l + lhs.code shouldBe "hash[:id]" + rhs.code shouldBe "s[:id]" + case xs => fail(s"Expected assignment call in true branch, got ${xs.code.mkString}") + } + + case xs => fail(s"Expected one control structure, got ${xs.code.mkString(",")}") + } + } + + "Bracketed /=" in { + val cpg = code(""" + |hash[:id] /= s[:id] + |""".stripMargin) + + inside(cpg.call.name(Operators.assignmentDivision).l) { + case assignmentCall :: Nil => + assignmentCall.code shouldBe "hash[:id] /= s[:id]" + assignmentCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + + inside(assignmentCall.argument.l) { + case lhs :: rhs :: Nil => + lhs.code shouldBe "hash[:id]" + rhs.code shouldBe "s[:id]" + case xs => fail(s"Expected lhs and rhs arguments, got ${xs.code.mkString(",")}") + } + case xs => fail(s"Expected on assignmentOr call, got ${xs.code.mkString(",")}") + } + } + + "Single ||= Assignment" in { + val cpg = code(""" + |def foo + | A.B ||= c 1 + |end + |""".stripMargin) + + inside(cpg.method.name("foo").controlStructure.l) { + case ifStruct :: Nil => + ifStruct.controlStructureType shouldBe ControlStructureTypes.IF + ifStruct.condition.code.l shouldBe List("!A.B") + + inside(ifStruct.whenTrue.ast.isCall.name(Operators.assignment).l) { + case assignmentCall :: Nil => + assignmentCall.code shouldBe "A.B = c 1" + val List(lhs, rhs: Call) = assignmentCall.argument.l: @unchecked + lhs.code shouldBe "A.B" + + rhs.code shouldBe "c 1" + val List(_, litArg) = rhs.argument.l + litArg.code shouldBe "1" + case xs => fail(s"Expected assignment call in true branch, got ${xs.code.mkString}") + } + + case xs => fail(s"Expected one if statement, got ${xs.code.mkString(",")}") + } + } + + "Single &&= Assignment" in { + val cpg = code(""" + |def foo + | A.B &&= c 1 + |end + |""".stripMargin) + + inside(cpg.method.name("foo").controlStructure.l) { + case ifStruct :: Nil => + ifStruct.controlStructureType shouldBe ControlStructureTypes.IF + ifStruct.condition.code.l shouldBe List("A.B") + + inside(ifStruct.whenTrue.ast.isCall.name(Operators.assignment).l) { + case assignmentCall :: Nil => + assignmentCall.code shouldBe "A.B = c 1" + val List(lhs: Call, rhs: Call) = assignmentCall.argument.l: @unchecked + lhs.code shouldBe "A.B" + lhs.methodFullName shouldBe Operators.fieldAccess + + rhs.code shouldBe "c 1" + val List(_, litArg) = rhs.argument.l + litArg.code shouldBe "1" + case xs => fail(s"Expected assignment call in true branch, got ${xs.code.mkString}") + } + + case xs => fail(s"Expected one if statement, got ${xs.code.mkString(",")}") + } + } + + "+= assignment operator" in { + val cpg = code(""" + |A::b += 1 + |""".stripMargin) + + inside(cpg.call.name(Operators.assignmentPlus).l) { + case assignmentCall :: Nil => + val List(lhs: Call, rhs) = assignmentCall.argument.l: @unchecked + + lhs.code shouldBe "A::b" + lhs.methodFullName shouldBe Operators.fieldAccess + + rhs.code shouldBe "1" + case xs => fail(s"Expected one call for assignment, got ${xs.code.mkString(",")}") + } + } + + "*= assignment operator" in { + val cpg = code(""" + |A::b *= 1 + |""".stripMargin) + + inside(cpg.call.name(Operators.assignmentMultiplication).l) { + case assignmentCall :: Nil => + assignmentCall.code shouldBe "A::b *= 1" + val List(lhs: Call, rhs) = assignmentCall.argument.l: @unchecked + + lhs.code shouldBe "A::b" + lhs.methodFullName shouldBe Operators.fieldAccess + + rhs.code shouldBe "1" + case xs => fail(s"Expected one call for assignment, got ${xs.code.mkString(",")}") + } + } + + "MethodInvocationWithoutParentheses multiple call args" in { + val cpg = code(""" + |def gl_badge_tag(*args, &block) + | render :some_symbol, &block + |end + |""".stripMargin) + + inside(cpg.call.name("render").argument.l) { + case _ :: (symbolArg: Literal) :: (blockArg: Identifier) :: Nil => + symbolArg.code shouldBe ":some_symbol" + blockArg.code shouldBe "block" + + case xs => fail(s"Expected two args, found [${xs.code.mkString(",")}]") + } + } + + "bitwise AND/OR assignments should parse correctly" in { + val cpg = code(""" + |x = 1 + |x &= 0 + |x |= 1 + |""".stripMargin) + + inside(cpg.assignment.l) { case _ :: and :: or :: Nil => + and.name shouldBe Operators.assignmentAnd + and.code shouldBe "x &= 0" + + or.name shouldBe Operators.assignmentOr + or.code shouldBe "x |= 1" + } + } + + "shift left/right assignments should parse correctly" in { + val cpg = code(""" + |x = 1 + |x >>= 1 + |x <<= 2 + |""".stripMargin) + + inside(cpg.assignment.l) { case _ :: sr :: sl :: Nil => + sr.name shouldBe Operators.assignmentArithmeticShiftRight + sr.code shouldBe "x >>= 1" + + sl.name shouldBe Operators.assignmentShiftLeft + sl.code shouldBe "x <<= 2" + } + } + + "global variable assignment" in { + val cpg = code(""" + |$alfred = "123" + |""".stripMargin) + + inside(cpg.call.name(Operators.assignment).l) { + case alfredAssign :: Nil => + alfredAssign.code shouldBe "$alfred = \"123\"" + + val List(lhs: Call, rhs: Literal) = alfredAssign.argument.l: @unchecked + + lhs.methodFullName shouldBe Operators.fieldAccess + lhs.code shouldBe "self.$alfred" + + rhs.code shouldBe "\"123\"" + case xs => fail(s"Expected one assignment call, got [${xs.code.mkString(",")}]") + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/testfixtures/RubyCode2CpgFixture.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/testfixtures/RubyCode2CpgFixture.scala index c8ec8791bd1d..e4d182fd7732 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/testfixtures/RubyCode2CpgFixture.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/testfixtures/RubyCode2CpgFixture.scala @@ -1,29 +1,29 @@ package io.joern.rubysrc2cpg.testfixtures +import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.language.Path -import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.joern.dataflowengineoss.testfixtures.{SemanticCpgTestFixture, SemanticTestCpg} -import io.joern.rubysrc2cpg.deprecated.utils.PackageTable import io.joern.rubysrc2cpg.{Config, RubySrc2Cpg} +import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.testfixtures.* -import io.joern.x2cpg.{ValidationMode, X2Cpg} import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.semanticcpg.language.{ICallResolver, NoResolve} -import org.scalatest.Tag +import org.scalatest.Inside import java.io.File -import org.scalatest.Inside +import java.nio.file.Files +import scala.jdk.CollectionConverters.* -trait RubyFrontend(useDeprecatedFrontend: Boolean, withDownloadDependencies: Boolean) extends LanguageFrontend { +trait RubyFrontend(withDownloadDependencies: Boolean, disableFileContent: Boolean) extends LanguageFrontend { override val fileSuffix: String = ".rb" implicit val config: Config = getConfig() .map(_.asInstanceOf[Config]) .getOrElse(Config().withSchemaValidation(ValidationMode.Enabled)) - .withUseDeprecatedFrontend(useDeprecatedFrontend) .withDownloadDependencies(withDownloadDependencies) + .withDisableFileContent(disableFileContent) override def execute(sourceCodeFile: File): Cpg = { new RubySrc2Cpg().createCpg(sourceCodeFile.getAbsolutePath).get @@ -31,12 +31,9 @@ trait RubyFrontend(useDeprecatedFrontend: Boolean, withDownloadDependencies: Boo } -class DefaultTestCpgWithRuby( - packageTable: Option[PackageTable], - useDeprecatedFrontend: Boolean, - downloadDependencies: Boolean = false -) extends DefaultTestCpg - with RubyFrontend(useDeprecatedFrontend, downloadDependencies) +class DefaultTestCpgWithRuby(downloadDependencies: Boolean = false, disableFileContent: Boolean = true) + extends DefaultTestCpg + with RubyFrontend(downloadDependencies, disableFileContent) with SemanticTestCpg { override protected def applyPasses(): Unit = { @@ -45,31 +42,24 @@ class DefaultTestCpgWithRuby( } override protected def applyPostProcessingPasses(): Unit = { - packageTable match { - case Some(table) => - RubySrc2Cpg.packageTableInfo.set(table) - case None => - } RubySrc2Cpg.postProcessingPasses(this, config).foreach(_.createAndApply()) } - } class RubyCode2CpgFixture( withPostProcessing: Boolean = false, withDataFlow: Boolean = false, downloadDependencies: Boolean = false, - extraFlows: List[FlowSemantic] = List.empty, - packageTable: Option[PackageTable] = None, - useDeprecatedFrontend: Boolean = false + disableFileContent: Boolean = true, + semantics: Semantics = DefaultSemantics() ) extends Code2CpgFixture(() => - new DefaultTestCpgWithRuby(packageTable, useDeprecatedFrontend, downloadDependencies) + new DefaultTestCpgWithRuby(downloadDependencies, disableFileContent) .withOssDataflow(withDataFlow) - .withExtraFlows(extraFlows) + .withSemantics(semantics) .withPostProcessingPasses(withPostProcessing) ) with Inside - with SemanticCpgTestFixture(extraFlows) { + with SemanticCpgTestFixture(semantics) { implicit val resolver: ICallResolver = NoResolve @@ -79,17 +69,9 @@ class RubyCode2CpgFixture( } } -class RubyCfgTestCpg(useDeprecatedFrontend: Boolean = true, downloadDependencies: Boolean = false) +class RubyCfgTestCpg(downloadDependencies: Boolean = false, disableFileContent: Boolean = true) extends CfgTestCpg - with RubyFrontend(useDeprecatedFrontend, downloadDependencies) { + with RubyFrontend(downloadDependencies, disableFileContent) { override val fileSuffix: String = ".rb" } - -/** Denotes a test which has been similarly ported to the new frontend. - */ -object SameInNewFrontend extends Tag("SameInNewFrontend") - -/** Denotes a test which has been ported to the new frontend, but has different expectations. - */ -object DifferentInNewFrontend extends Tag("DifferentInNewFrontend") diff --git a/joern-cli/frontends/swiftsrc2cpg/build.sbt b/joern-cli/frontends/swiftsrc2cpg/build.sbt index e9c4c39045b9..2f62a3e23165 100644 --- a/joern-cli/frontends/swiftsrc2cpg/build.sbt +++ b/joern-cli/frontends/swiftsrc2cpg/build.sbt @@ -68,16 +68,15 @@ astGenDlTask := { val astGenDir = baseDirectory.value / "bin" / "astgen" astGenBinaryNames.value.foreach { fileName => - DownloadHelper.ensureIsAvailable(s"${astGenDlUrl.value}$fileName", astGenDir / fileName) + val file = astGenDir / fileName + DownloadHelper.ensureIsAvailable(s"${astGenDlUrl.value}$fileName", file) + // permissions are lost during the download; need to set them manually + file.setExecutable(true, false) } val distDir = (Universal / stagingDirectory).value / "bin" / "astgen" distDir.mkdirs() - IO.copyDirectory(astGenDir, distDir) - - // permissions are lost during the download; need to set them manually - astGenDir.listFiles().foreach(_.setExecutable(true, false)) - distDir.listFiles().foreach(_.setExecutable(true, false)) + IO.copyDirectory(astGenDir, distDir, preserveExecutable = true) } Compile / compile := ((Compile / compile) dependsOn astGenDlTask).value @@ -92,3 +91,7 @@ stage := Def Universal / packageName := name.value Universal / topLevelDirectory := None + +/** write the astgen version to the manifest for downstream usage */ +Compile / packageBin / packageOptions += + Package.ManifestAttributes(new java.util.jar.Attributes.Name("Swift-AstGen-Version") -> astGenVersion.value) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/Main.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/Main.scala index b9d071df2273..2d2d37dd6257 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/Main.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/Main.scala @@ -4,6 +4,7 @@ import io.joern.swiftsrc2cpg.Frontend.* import io.joern.x2cpg.passes.frontend.{TypeRecoveryParserConfig, XTypeRecovery, XTypeRecoveryConfig} import io.joern.x2cpg.utils.Environment import io.joern.x2cpg.{X2CpgConfig, X2CpgMain} +import io.joern.x2cpg.utils.server.FrontendHTTPServer import scopt.OParser import java.nio.file.Paths @@ -34,14 +35,19 @@ object Frontend { } -object Main extends X2CpgMain(cmdLineParser, new SwiftSrc2Cpg()) { +object Main extends X2CpgMain(cmdLineParser, new SwiftSrc2Cpg()) with FrontendHTTPServer[Config, SwiftSrc2Cpg] { + + override protected def newDefaultConfig(): Config = Config() def run(config: Config, swiftsrc2cpg: SwiftSrc2Cpg): Unit = { - val absPath = Paths.get(config.inputPath).toAbsolutePath.toString - if (Environment.pathExists(absPath)) { - swiftsrc2cpg.run(config.withInputPath(absPath)) - } else { - System.exit(1) + if (config.serverMode) { startup() } + else { + val absPath = Paths.get(config.inputPath).toAbsolutePath.toString + if (Environment.pathExists(absPath)) { + swiftsrc2cpg.run(config.withInputPath(absPath)) + } else { + System.exit(1) + } } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala index 43d16968e363..af843e9d2568 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala @@ -21,7 +21,7 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewTypeRef import io.shiftleft.codepropertygraph.generated.ModifierTypes import io.shiftleft.codepropertygraph.generated.nodes.File.PropertyDefaults import org.slf4j.{Logger, LoggerFactory} -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import scala.collection.mutable @@ -121,10 +121,10 @@ class AstCreator(val config: Config, val global: Global, val parserResult: Parse case null => notHandledYet(node) } - override protected def line(node: SwiftNode): Option[Int] = node.startLine.map(Integer.valueOf) - override protected def column(node: SwiftNode): Option[Int] = node.startColumn.map(Integer.valueOf) - override protected def lineEnd(node: SwiftNode): Option[Int] = node.endLine.map(Integer.valueOf) - override protected def columnEnd(node: SwiftNode): Option[Int] = node.endColumn.map(Integer.valueOf) + override protected def line(node: SwiftNode): Option[Int] = node.startLine + override protected def column(node: SwiftNode): Option[Int] = node.startColumn + override protected def lineEnd(node: SwiftNode): Option[Int] = node.endLine + override protected def columnEnd(node: SwiftNode): Option[Int] = node.endColumn private val lineOffsetTable = OffsetUtils.getLineOffsetTable(Option(parserResult.fileContent)) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala index 20b47f96f7bb..7a60e48fe4d2 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala @@ -10,6 +10,7 @@ import io.joern.swiftsrc2cpg.parser.SwiftNodeSyntax.GuardStmtSyntax import io.joern.swiftsrc2cpg.parser.SwiftNodeSyntax.InitializerDeclSyntax import io.joern.swiftsrc2cpg.parser.SwiftNodeSyntax.SwiftNode import io.joern.x2cpg.frontendspecific.swiftsrc2cpg.Defines +import io.joern.x2cpg.utils.IntervalKeyPool import io.joern.x2cpg.{Ast, ValidationMode} import io.joern.x2cpg.utils.NodeBuilders.{newClosureBindingNode, newLocalNode} import io.shiftleft.codepropertygraph.generated.nodes.NewNode @@ -18,7 +19,6 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewNamespaceBlock import io.shiftleft.codepropertygraph.generated.nodes.NewTypeDecl import io.shiftleft.codepropertygraph.generated.ControlStructureTypes import io.shiftleft.codepropertygraph.generated.PropertyNames -import io.shiftleft.passes.IntervalKeyPool import scala.collection.mutable diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/ImportsPass.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/ImportsPass.scala index 963e5e4e66c4..d89d9b8f5dbe 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/ImportsPass.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/ImportsPass.scala @@ -4,7 +4,7 @@ import io.joern.x2cpg.X2Cpg import io.joern.x2cpg.passes.frontend.XImportsPass import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Call -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment /** This pass creates `IMPORT` nodes by looking for calls to `require`. `IMPORT` nodes are linked to existing dependency diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/SwiftTypeNodePass.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/SwiftTypeNodePass.scala index 75a5d0348985..1e78013d2b25 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/SwiftTypeNodePass.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/SwiftTypeNodePass.scala @@ -3,15 +3,14 @@ package io.joern.swiftsrc2cpg.passes import io.shiftleft.codepropertygraph.generated.Cpg import io.joern.x2cpg.passes.frontend.TypeNodePass import io.shiftleft.semanticcpg.language.* -import io.shiftleft.passes.KeyPool import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import scala.collection.mutable object SwiftTypeNodePass { - def withRegisteredTypes(registeredTypes: List[String], cpg: Cpg, keyPool: Option[KeyPool] = None): TypeNodePass = { - new TypeNodePass(registeredTypes, cpg, keyPool, getTypesFromCpg = false) { + def withRegisteredTypes(registeredTypes: List[String], cpg: Cpg): TypeNodePass = { + new TypeNodePass(registeredTypes, cpg, getTypesFromCpg = false) { override def fullToShortName(typeName: String): String = { typeName match { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/AstGenRunner.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/AstGenRunner.scala index 1b2ef98e9a94..88b07c32917d 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/AstGenRunner.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/AstGenRunner.scala @@ -65,7 +65,7 @@ object AstGenRunner { val astGenCommand = path.getOrElse("SwiftAstGen") val localPath = path.flatMap(File(_).parentOption.map(_.pathAsString)).getOrElse(".") val debugMsgPath = path.getOrElse("PATH") - ExternalCommand.run(s"$astGenCommand -h", localPath).toOption match { + ExternalCommand.run(Seq(astGenCommand, "-h"), localPath).toOption match { case Some(_) => logger.debug(s"Using SwiftAstGen from $debugMsgPath") true @@ -140,7 +140,7 @@ class AstGenRunner(config: Config) { } private def runAstGenNative(in: File, out: File): Try[Seq[String]] = - ExternalCommand.run(s"$astGenCommand -o $out", in.toString()) + ExternalCommand.run(Seq(astGenCommand, "-o", out.toString), in.toString()) private def checkParsedFiles(files: List[String], in: File): List[String] = { val numOfParsedFiles = files.size diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/ExternalCommand.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/ExternalCommand.scala index b2ba122f8a5d..84ff1422e78c 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/ExternalCommand.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/ExternalCommand.scala @@ -4,19 +4,20 @@ import scala.util.Failure import scala.util.Success import scala.util.Try -object ExternalCommand extends io.joern.x2cpg.utils.ExternalCommand { +object ExternalCommand { - override def handleRunResult(result: Try[Int], stdOut: Seq[String], stdErr: Seq[String]): Try[Seq[String]] = { - result match { - case Success(0) => + import io.joern.x2cpg.utils.ExternalCommand.ExternalCommandResult + + def run(command: Seq[String], cwd: String, extraEnv: Map[String, String] = Map.empty): Try[Seq[String]] = { + io.joern.x2cpg.utils.ExternalCommand.run(command, cwd, mergeStdErrInStdOut = true, extraEnv) match { + case ExternalCommandResult(0, stdOut, _) => Success(stdOut) - case Success(_) if stdErr.isEmpty && stdOut.nonEmpty => + case ExternalCommandResult(_, stdOut, stdErr) if stdErr.isEmpty && stdOut.nonEmpty => // SwiftAstGen exits with exit code != 0 on Windows. // To catch with we specifically handle the empty stdErr here. Success(stdOut) - case _ => - val allOutput = stdOut ++ stdErr - Failure(new RuntimeException(allOutput.mkString(System.lineSeparator()))) + case ExternalCommandResult(_, stdOut, _) => + Failure(new RuntimeException(stdOut.mkString(System.lineSeparator()))) } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/dataflow/DataFlowTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/dataflow/DataFlowTests.scala index e7de93beec41..15ac7441b461 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/dataflow/DataFlowTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/dataflow/DataFlowTests.scala @@ -8,7 +8,6 @@ import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.Identifier import io.shiftleft.codepropertygraph.generated.nodes.Literal import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.toNodeTraversal class DataFlowTests extends DataFlowCodeToCpgSuite { @@ -871,7 +870,7 @@ class DataFlowTests extends DataFlowCodeToCpgSuite { cpg .call("bar") .outE(EdgeTypes.REACHING_DEF) - .count(_.inNode() == cpg.ret.head) shouldBe 1 + .count(_.dst == cpg.ret.head) shouldBe 1 } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/io/CodeDumperFromFileTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/io/CodeDumperFromFileTests.scala index 401379610cca..74070859371c 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/io/CodeDumperFromFileTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/io/CodeDumperFromFileTests.scala @@ -3,7 +3,7 @@ package io.joern.swiftsrc2cpg.io import better.files.File import io.joern.swiftsrc2cpg.testfixtures.SwiftSrc2CpgSuite import io.shiftleft.semanticcpg.codedumper.CodeDumper -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.util.regex.Pattern diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/io/SwiftSrc2CpgHTTPServerTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/io/SwiftSrc2CpgHTTPServerTests.scala new file mode 100644 index 000000000000..fccf57f5b5e8 --- /dev/null +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/io/SwiftSrc2CpgHTTPServerTests.scala @@ -0,0 +1,82 @@ +package io.joern.swiftsrc2cpg.io + +import better.files.File +import io.joern.x2cpg.utils.server.FrontendHTTPClient +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader +import io.shiftleft.semanticcpg.language.* +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.collection.parallel.CollectionConverters.RangeIsParallelizable +import scala.util.Failure +import scala.util.Success + +class SwiftSrc2CpgHTTPServerTests extends AnyWordSpec with Matchers with BeforeAndAfterAll { + + private var port: Int = -1 + + private def newProjectUnderTest(index: Option[Int] = None): File = { + val dir = File.newTemporaryDirectory("swiftsrc2cpgTestsHttpTest") + val file = dir / "main.swift" + file.createIfNotExists(createParents = true) + val indexStr = index.map(_.toString).getOrElse("") + file.writeText(s""" + |func main() { + | println($indexStr) + |}""".stripMargin) + file.deleteOnExit() + dir.deleteOnExit() + } + + override def beforeAll(): Unit = { + // Start server + port = io.joern.swiftsrc2cpg.Main.startup() + } + + override def afterAll(): Unit = { + // Stop server + io.joern.swiftsrc2cpg.Main.stop() + } + + "Using swiftsrc2cpg in server mode" should { + "build CPGs correctly (single test)" in { + val cpgOutFile = File.newTemporaryFile("swiftsrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest() + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l should contain("println()") + } + } + + "build CPGs correctly (multi-threaded test)" in { + (0 until 10).par.foreach { index => + val cpgOutFile = File.newTemporaryFile("swiftsrc2cpg.bin") + cpgOutFile.deleteOnExit() + val projectUnderTest = newProjectUnderTest(Some(index)) + val input = projectUnderTest.path.toAbsolutePath.toString + val output = cpgOutFile.toString + val client = FrontendHTTPClient(port) + val req = client.buildRequest(Array(s"input=$input", s"output=$output")) + client.sendRequest(req) match { + case Failure(exception) => fail(exception.getMessage) + case Success(out) => + out shouldBe output + val cpg = CpgLoader.load(output) + cpg.method.name.l should contain("main") + cpg.call.code.l should contain(s"println($index)") + } + } + } + } + +} diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ActorTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ActorTests.scala index 9f35b700b653..ac7a7b581e3a 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ActorTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ActorTests.scala @@ -3,9 +3,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class ActorTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala index bc5471c202f6..f49f1b07708d 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala @@ -3,9 +3,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class AvailabilityQueryTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/BorrowExprTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/BorrowExprTests.scala index 346292310e0b..41d5bfb74245 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/BorrowExprTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/BorrowExprTests.scala @@ -3,7 +3,7 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class BorrowExprTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/BuiltinWordTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/BuiltinWordTests.scala index 1619e9046cf2..01e2535d6108 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/BuiltinWordTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/BuiltinWordTests.scala @@ -3,7 +3,7 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class BuiltinWordTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ConflictMarkersTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ConflictMarkersTests.scala index 83c8113a0304..1c102a80e460 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ConflictMarkersTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ConflictMarkersTests.scala @@ -3,7 +3,7 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class ConflictMarkersTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/CopyExprTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/CopyExprTests.scala index a37615b6b7c0..74cdd5dd2ad3 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/CopyExprTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/CopyExprTests.scala @@ -3,7 +3,7 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CopyExprTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/DeclarationTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/DeclarationTests.scala index d296aa4a8c6e..07ba4e56a134 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/DeclarationTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/DeclarationTests.scala @@ -2,9 +2,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class DeclarationTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/EnumTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/EnumTests.scala index 4c06817fe338..eb1f77e8ac99 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/EnumTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/EnumTests.scala @@ -4,9 +4,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class EnumTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ExpressionTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ExpressionTests.scala index a34be3b027d8..fe36cd54b82f 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ExpressionTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ExpressionTests.scala @@ -2,9 +2,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class ExpressionTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ForeachTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ForeachTests.scala index 034f2cd05a6b..82baf8f5365f 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ForeachTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ForeachTests.scala @@ -2,9 +2,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class ForeachTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala index 071f028346f7..99102b6f6b62 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala @@ -2,9 +2,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class StatementTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/SuperTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/SuperTests.scala index 3991c1d143f5..c3e538e433a3 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/SuperTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/SuperTests.scala @@ -4,7 +4,7 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class SuperTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/SwitchTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/SwitchTests.scala index a49ca2a5a8b5..1532d02ad0da 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/SwitchTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/SwitchTests.scala @@ -4,9 +4,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class SwitchTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ToplevelLibraryTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ToplevelLibraryTests.scala index eda0969c3d9c..a70e3645c9ba 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ToplevelLibraryTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ToplevelLibraryTests.scala @@ -4,9 +4,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class ToplevelLibraryTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/TryTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/TryTests.scala index cc04f6a56439..92f9e092680e 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/TryTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/TryTests.scala @@ -4,9 +4,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class TryTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/TypealiasTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/TypealiasTests.scala index 97ecb3006bdc..9a974501a878 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/TypealiasTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/TypealiasTests.scala @@ -4,9 +4,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class TypealiasTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala index df30834877dd..6138b719d8e1 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala @@ -2,9 +2,9 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.AstSwiftSrc2CpgSuite -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* class WhileTests extends AstSwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/testfixtures/DataFlowCodeToCpgSuite.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/testfixtures/DataFlowCodeToCpgSuite.scala index 451f2dfefeb0..af0875c95404 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/testfixtures/DataFlowCodeToCpgSuite.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/testfixtures/DataFlowCodeToCpgSuite.scala @@ -27,7 +27,7 @@ class DataFlowCodeToCpgSuite extends Code2CpgFixture(() => new DataFlowTestCpg() protected implicit val context: EngineContext = EngineContext() protected def flowToResultPairs(path: Path): List[(String, Integer)] = - path.resultPairs().collect { case (firstElement: String, secondElement: Option[Integer]) => + path.resultPairs().collect { case (firstElement: String, secondElement) => (firstElement, secondElement.getOrElse(-1)) } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/testfixtures/SwiftSrc2CpgSuite.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/testfixtures/SwiftSrc2CpgSuite.scala index 1eeebaffd7d5..2aa8cc29b730 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/testfixtures/SwiftSrc2CpgSuite.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/testfixtures/SwiftSrc2CpgSuite.scala @@ -1,16 +1,17 @@ package io.joern.swiftsrc2cpg.testfixtures -import io.joern.dataflowengineoss.semanticsloader.FlowSemantic +import io.joern.dataflowengineoss.DefaultSemantics +import io.joern.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.joern.x2cpg.testfixtures.Code2CpgFixture class SwiftSrc2CpgSuite( fileSuffix: String = ".swift", withOssDataflow: Boolean = false, - extraFlows: List[FlowSemantic] = List.empty, + semantics: Semantics = DefaultSemantics(), withPostProcessing: Boolean = false ) extends Code2CpgFixture(() => new SwiftDefaultTestCpg(fileSuffix) .withOssDataflow(withOssDataflow) - .withExtraFlows(extraFlows) + .withSemantics(semantics) .withPostProcessingPasses(withPostProcessing) ) diff --git a/joern-cli/frontends/x2cpg/build.sbt b/joern-cli/frontends/x2cpg/build.sbt index dcf0bde3b5cc..4aa7fbfb1376 100644 --- a/joern-cli/frontends/x2cpg/build.sbt +++ b/joern-cli/frontends/x2cpg/build.sbt @@ -4,10 +4,12 @@ dependsOn(Projects.semanticcpg) libraryDependencies ++= Seq( /* Start: AST Gen Dependencies */ - "com.lihaoyi" %% "upickle" % Versions.upickle, - "com.typesafe" % "config" % Versions.typeSafeConfig, - "com.michaelpollmeier" % "versionsort" % Versions.versionSort, + "com.lihaoyi" %% "upickle" % Versions.upickle, + "com.typesafe" % "config" % Versions.typeSafeConfig, + "com.michaelpollmeier" % "versionsort" % Versions.versionSort, + "org.apache.commons" % "commons-exec" % Versions.commonsExec, /* End: AST Gen Dependencies */ + "net.freeutils" % "jlhttp" % Versions.jlhttp, "org.gradle" % "gradle-tooling-api" % Versions.gradleTooling % Optional, "org.scalatest" %% "scalatest" % Versions.scalatest % Test ) diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala index ea5f9f8bd5f4..df185d8addcf 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala @@ -1,11 +1,10 @@ package io.joern.x2cpg +import flatgraph.SchemaViolationException +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.nodes.AstNode.PropertyDefaults -import org.slf4j.LoggerFactory -import overflowdb.BatchedUpdate.DiffGraphBuilder -import overflowdb.SchemaViolationException case class AstEdge(src: NewNode, dst: NewNode) @@ -15,8 +14,6 @@ enum ValidationMode { object Ast { - private val logger = LoggerFactory.getLogger(getClass) - def apply(node: NewNode)(implicit withSchemaValidation: ValidationMode): Ast = Ast(Vector.empty :+ node) def apply()(implicit withSchemaValidation: ValidationMode): Ast = new Ast(Vector.empty) @@ -49,6 +46,10 @@ object Ast { ast.bindsEdges.foreach { edge => diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS) } + + ast.captureEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CAPTURE) + } } def neighbourValidation(src: NewNode, dst: NewNode, edge: String)(implicit @@ -58,7 +59,7 @@ object Ast { !(src.isValidOutNeighbor(edge, dst) && dst.isValidInNeighbor(edge, src)) ) { throw new SchemaViolationException( - s"Malformed AST detected: (${src.label()}) -[$edge]-> (${dst.label()}) violates the schema." + s"Malformed AST detected: (${src.label}) -[$edge]-> (${dst.label}) violates the schema." ) } @@ -92,7 +93,8 @@ case class Ast( refEdges: collection.Seq[AstEdge] = Vector.empty, bindsEdges: collection.Seq[AstEdge] = Vector.empty, receiverEdges: collection.Seq[AstEdge] = Vector.empty, - argEdges: collection.Seq[AstEdge] = Vector.empty + argEdges: collection.Seq[AstEdge] = Vector.empty, + captureEdges: collection.Seq[AstEdge] = Vector.empty )(implicit withSchemaValidation: ValidationMode = ValidationMode.Disabled) { def root: Option[NewNode] = nodes.headOption @@ -114,7 +116,8 @@ case class Ast( argEdges = argEdges ++ other.argEdges, receiverEdges = receiverEdges ++ other.receiverEdges, refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges + bindsEdges = bindsEdges ++ other.bindsEdges, + captureEdges = captureEdges ++ other.captureEdges ) } @@ -126,7 +129,8 @@ case class Ast( argEdges = argEdges ++ other.argEdges, receiverEdges = receiverEdges ++ other.receiverEdges, refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges + bindsEdges = bindsEdges ++ other.bindsEdges, + captureEdges = captureEdges ++ other.captureEdges ) } @@ -217,20 +221,29 @@ case class Ast( this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _))) } + def withCaptureEdge(src: NewNode, dst: NewNode): Ast = { + Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE) + this.copy(captureEdges = captureEdges ++ List(AstEdge(src, dst))) + } + + def withCaptureEdges(src: NewNode, dsts: Seq[NewNode]): Ast = { + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE)) + this.copy(captureEdges = captureEdges ++ dsts.map(AstEdge(src, _))) + } + /** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` and `argumentIndex` * fields of the new root node are set to `order`. If `replacementNode` is set, then this replaces `node` in the new * copy. */ def subTreeCopy(node: AstNodeNew, argIndex: Int = -1, replacementNode: Option[AstNodeNew] = None): Ast = { - val newNode = replacementNode match + val newNode = replacementNode match { case Some(n) => n case None => node.copy + } if (argIndex != -1) { - // newNode.order = argIndex newNode match { - case expr: ExpressionNew => - expr.argumentIndex = argIndex - case _ => + case expr: ExpressionNew => expr.argumentIndex = argIndex + case _ => } } @@ -249,6 +262,7 @@ case class Ast( val newRefEdges = refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) val newBindsEdges = bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) val newReceiverEdges = receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + val newCaptureEdges = captureEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) Ast(newNode) .copy( @@ -256,7 +270,8 @@ case class Ast( conditionEdges = newConditionEdges, refEdges = newRefEdges, bindsEdges = newBindsEdges, - receiverEdges = newReceiverEdges + receiverEdges = newReceiverEdges, + captureEdges = newCaptureEdges ) .withChildren(newChildren) } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala index 8df7a930c78c..620cd2b1de1a 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala @@ -1,13 +1,11 @@ package io.joern.x2cpg import io.joern.x2cpg.passes.frontend.MetaDataPass -import io.joern.x2cpg.utils.NodeBuilders.newMethodReturnNode -import io.shiftleft.codepropertygraph.generated.Cpg +import io.joern.x2cpg.utils.IntervalKeyPool +import io.joern.x2cpg.utils.NodeBuilders.{newFieldIdentifierNode, newMethodReturnNode, newOperatorCallNode} +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Cpg, DiffGraphBuilder, ModifierTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, ModifierTypes} -import io.shiftleft.passes.IntervalKeyPool import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal -import overflowdb.BatchedUpdate.DiffGraphBuilder abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: ValidationMode) { val diffGraph: DiffGraphBuilder = Cpg.newDiffGraphBuilder @@ -88,7 +86,7 @@ abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: V ): Ast = Ast(method) .withChildren(parameters) - .withChild(Ast(NewBlock())) + .withChild(Ast(NewBlock().typeFullName(Defines.Any))) .withChildren(modifiers.map(Ast(_))) .withChild(Ast(methodReturn)) @@ -113,7 +111,7 @@ abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: V methodNode.filename(fileName.get) } val staticModifier = NewModifier().modifierType(ModifierTypes.STATIC) - val body = blockAst(NewBlock(), initAsts) + val body = blockAst(NewBlock().typeFullName(Defines.Any), initAsts) val methodReturn = newMethodReturnNode(returnType, None, None, None) methodAst(methodNode, Nil, body, methodReturn, List(staticModifier)) } @@ -150,9 +148,9 @@ abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: V def wrapMultipleInBlock(asts: Seq[Ast], lineNumber: Option[Int]): Ast = { asts.toList match { - case Nil => blockAst(NewBlock().lineNumber(lineNumber)) + case Nil => blockAst(NewBlock().typeFullName(Defines.Any).lineNumber(lineNumber)) case ast :: Nil => ast - case astList => blockAst(NewBlock().lineNumber(lineNumber), astList) + case astList => blockAst(NewBlock().typeFullName(Defines.Any).lineNumber(lineNumber), astList) } } @@ -200,6 +198,13 @@ abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: V ): Ast = forAst(forNode, locals, initAsts, conditionAsts, updateAsts, Seq(bodyAst)) + private def setOrderExplicitly(ast: Ast, order: Int): Ast = { + ast.root match { + case Some(value: ExpressionNew) => value.order(order); ast + case _ => ast + } + } + def forAst( forNode: NewControlStructure, locals: Seq[Ast], @@ -208,12 +213,15 @@ abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: V updateAsts: Seq[Ast], bodyAsts: Seq[Ast] ): Ast = { - val lineNumber = forNode.lineNumber + val lineNumber = forNode.lineNumber + val numOfLocals = locals.size + // for the expected orders see CfgCreator.cfgForForStatement + if (bodyAsts.nonEmpty) setOrderExplicitly(bodyAsts.head, numOfLocals + 4) Ast(forNode) .withChildren(locals) - .withChild(wrapMultipleInBlock(initAsts, lineNumber)) - .withChild(wrapMultipleInBlock(conditionAsts, lineNumber)) - .withChild(wrapMultipleInBlock(updateAsts, lineNumber)) + .withChild(setOrderExplicitly(wrapMultipleInBlock(initAsts, lineNumber), numOfLocals + 1)) + .withChild(setOrderExplicitly(wrapMultipleInBlock(conditionAsts, lineNumber), numOfLocals + 2)) + .withChild(setOrderExplicitly(wrapMultipleInBlock(updateAsts, lineNumber), numOfLocals + 3)) .withChildren(bodyAsts) .withConditionEdges(forNode, conditionAsts.flatMap(_.root).toList) } @@ -310,8 +318,8 @@ abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: V .withReceiverEdges(callNode, receiverRoot) } - def setArgumentIndices(arguments: Seq[Ast]): Unit = { - var currIndex = 1 + def setArgumentIndices(arguments: Seq[Ast], start: Int = 1): Unit = { + var currIndex = start arguments.foreach { a => a.root match { case Some(x: ExpressionNew) => @@ -324,6 +332,21 @@ abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: V } } + def fieldAccessAst( + base: Ast, + code: String, + lineNo: Option[Int], + columnNo: Option[Int], + fieldName: String, + fieldTypeFullName: String, + fieldLineNo: Option[Int], + fieldColumnNo: Option[Int] + ): Ast = { + val callNode = newOperatorCallNode(Operators.fieldAccess, code, Some(fieldTypeFullName), lineNo, columnNo) + val fieldIdentifierNode = newFieldIdentifierNode(fieldName, fieldLineNo, fieldColumnNo) + callAst(callNode, Seq(base, Ast(fieldIdentifierNode))) + } + def withIndex[T, X](nodes: Seq[T])(f: (T, Int) => X): Seq[X] = nodes.zipWithIndex.map { case (x, i) => f(x, i + 1) diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstNodeBuilder.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstNodeBuilder.scala index 2b2832960813..c7ff1e17b10b 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstNodeBuilder.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstNodeBuilder.scala @@ -1,7 +1,7 @@ package io.joern.x2cpg -import io.joern.x2cpg.utils.NodeBuilders.newMethodReturnNode -import io.shiftleft.codepropertygraph.generated.nodes.Block.{PropertyDefaults => BlockDefaults} +import io.joern.x2cpg.utils.NodeBuilders.{newMethodReturnNode, newOperatorCallNode} +import io.shiftleft.codepropertygraph.generated.nodes.Block.PropertyDefaults as BlockDefaults import io.shiftleft.codepropertygraph.generated.nodes.{ NewAnnotation, NewBlock, @@ -79,15 +79,18 @@ trait AstNodeBuilder[Node, NodeProcessor] { this: NodeProcessor => name: String, code: String, typeFullName: String, - dynamicTypeHints: Seq[String] = Seq() + dynamicTypeHints: Seq[String] = Seq(), + genericSignature: Option[String] = None ): NewMember = { - NewMember() + val member = NewMember() .code(code) .name(name) .typeFullName(typeFullName) .dynamicTypeHintFullName(dynamicTypeHints) .lineNumber(line(node)) .columnNumber(column(node)) + genericSignature.foreach(member.genericSignature(_)) + member } protected def newImportNode(code: String, importedEntity: String, importedAs: String, include: Node): NewImport = { @@ -140,7 +143,8 @@ trait AstNodeBuilder[Node, NodeProcessor] { this: NodeProcessor => astParentType: String = "", astParentFullName: String = "", inherits: Seq[String] = Seq.empty, - alias: Option[String] = None + alias: Option[String] = None, + genericSignature: Option[String] = None ): NewTypeDecl = { val node_ = NewTypeDecl() .name(name) @@ -157,6 +161,7 @@ trait AstNodeBuilder[Node, NodeProcessor] { this: NodeProcessor => offset(node).foreach { case (offset, offsetEnd) => node_.offset(offset).offsetEnd(offsetEnd) } + genericSignature.foreach(node_.genericSignature(_)) node_ } @@ -217,6 +222,10 @@ trait AstNodeBuilder[Node, NodeProcessor] { this: NodeProcessor => out } + protected def operatorCallNode(node: Node, name: String, typeFullName: Option[String]): NewCall = { + newOperatorCallNode(name, code(node), typeFullName, line(node), column(node)) + } + protected def returnNode(node: Node, code: String): NewReturn = { NewReturn() .code(code) @@ -234,7 +243,7 @@ trait AstNodeBuilder[Node, NodeProcessor] { this: NodeProcessor => } protected def blockNode(node: Node): NewBlock = { - blockNode(node, BlockDefaults.Code, BlockDefaults.TypeFullName) + blockNode(node, BlockDefaults.Code, Defines.Any) } protected def blockNode(node: Node, code: String, typeFullName: String): NewBlock = { @@ -258,15 +267,19 @@ trait AstNodeBuilder[Node, NodeProcessor] { this: NodeProcessor => name: String, code: String, typeFullName: String, - closureBindingId: Option[String] = None - ): NewLocal = - NewLocal() + closureBindingId: Option[String] = None, + genericSignature: Option[String] = None + ): NewLocal = { + val local = NewLocal() .name(name) .code(code) .typeFullName(typeFullName) .closureBindingId(closureBindingId) .lineNumber(line(node)) .columnNumber(column(node)) + genericSignature.foreach(local.genericSignature(_)) + local + } protected def identifierNode( node: Node, @@ -296,7 +309,8 @@ trait AstNodeBuilder[Node, NodeProcessor] { this: NodeProcessor => signature: Option[String], fileName: String, astParentType: Option[String] = None, - astParentFullName: Option[String] = None + astParentFullName: Option[String] = None, + genericSignature: Option[String] = None ): NewMethod = { val node_ = NewMethod() @@ -312,6 +326,7 @@ trait AstNodeBuilder[Node, NodeProcessor] { this: NodeProcessor => .lineNumberEnd(lineEnd(node)) .columnNumberEnd(columnEnd(node)) signature.foreach { s => node_.signature(StringUtils.normalizeSpace(s)) } + genericSignature.foreach(node_.genericSignature(_)) offset(node).foreach { case (offset, offsetEnd) => node_.offset(offset).offsetEnd(offsetEnd) } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Defines.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Defines.scala index 13f523ab068e..1de9aa321c23 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Defines.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Defines.scala @@ -36,4 +36,9 @@ object Defines { val LeftAngularBracket = "<" val Unknown = "" + + // Used for field access calls in the lowering of pattern extractors where the field name + // may not be known. As an example in javasrc2cpg, the assignment for `o instanceof Foo(Bar b))` could + // be lowered to `Bar b = (Bar) (((Foo) o).)` + val UnknownField = "" } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Imports.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Imports.scala index a46e41fdabaf..7d523c6aba44 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Imports.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Imports.scala @@ -2,7 +2,7 @@ package io.joern.x2cpg import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.{CallBase, NewImport} -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder object Imports { diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/SourceFiles.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/SourceFiles.scala index 24b224a15fab..af768312b594 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/SourceFiles.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/SourceFiles.scala @@ -1,21 +1,91 @@ package io.joern.x2cpg -import better.files.File.VisitOptions import better.files.* +import better.files.File.VisitOptions import org.slf4j.LoggerFactory import java.io.FileNotFoundException +import java.nio.file.FileVisitor +import java.nio.file.FileVisitResult +import java.nio.file.Path import java.nio.file.Paths +import java.nio.file.attribute.BasicFileAttributes +import java.nio.file.Files +import scala.jdk.CollectionConverters.SetHasAsJava import scala.util.matching.Regex object SourceFiles { private val logger = LoggerFactory.getLogger(getClass) + /** A failsafe implementation of a [[FileVisitor]] that continues iterating through files even if an [[IOException]] + * occurs during traversal. + * + * This visitor determines during traversal whether a given file should be excluded based on several criteria, such + * as matching default ignore patterns, specific file name patterns, or explicit file paths to ignore. It does not + * descent into folders matching such ignore patterns. + * + * This class is useful in scenarios where file traversal must be resilient to errors, such as accessing files with + * restricted permissions or encountering corrupted file entries. + * + * @param inputPath + * The root path from which the file traversal starts. + * @param ignoredDefaultRegex + * Optional sequence of regular expressions to filter out default ignored file patterns. + * @param ignoredFilesRegex + * Optional regular expression to filter out specific files based on their names. + * @param ignoredFilesPath + * Optional sequence of file paths to exclude from traversal explicitly. + */ + private final class FailsafeFileVisitor( + inputPath: String, + sourceFileExtensions: Set[String], + ignoredDefaultRegex: Option[Seq[Regex]] = None, + ignoredFilesRegex: Option[Regex] = None, + ignoredFilesPath: Option[Seq[String]] = None + ) extends FileVisitor[Path] { + + private val seenFiles = scala.collection.mutable.ArrayBuffer.empty[Path] + + def files(): Array[File] = seenFiles.map(File(_)).toArray + + override def preVisitDirectory(dir: Path, attrs: BasicFileAttributes): FileVisitResult = { + if (filterFile(dir.toString, inputPath, ignoredDefaultRegex, ignoredFilesRegex, ignoredFilesPath)) { + FileVisitResult.CONTINUE + } else { + FileVisitResult.SKIP_SUBTREE + } + } + + override def visitFile(file: Path, attrs: BasicFileAttributes): FileVisitResult = { + if ( + hasSourceFileExtension(file, sourceFileExtensions) && + filterFile(file.toString, inputPath, ignoredDefaultRegex, ignoredFilesRegex, ignoredFilesPath) + ) { seenFiles.addOne(file) } + FileVisitResult.CONTINUE + } + + override def visitFileFailed(file: Path, exc: java.io.IOException): FileVisitResult = { + exc match { + case _: java.nio.file.FileSystemLoopException => logger.warn(s"Ignoring '$file' (cyclic symlink)") + case other => logger.warn(s"Ignoring '$file'", other) + } + FileVisitResult.CONTINUE + } + + override def postVisitDirectory(dir: Path, exc: java.io.IOException): FileVisitResult = FileVisitResult.CONTINUE + } + private def isIgnoredByFileList(filePath: String, ignoredFiles: Seq[String]): Boolean = { - val isInIgnoredFiles = ignoredFiles.exists { - case ignorePath if File(ignorePath).isDirectory => filePath.startsWith(ignorePath) - case ignorePath => filePath == ignorePath + val filePathFile = File(filePath) + if (!filePathFile.exists || !filePathFile.isReadable) { + logger.debug(s"'$filePath' ignored (not readable or broken symlink)") + return true + } + val isInIgnoredFiles = ignoredFiles.exists { ignorePath => + val ignorePathFile = File(ignorePath) + ignorePathFile.exists && + (ignorePathFile.contains(filePathFile, strict = false) || ignorePathFile.isSameFileAs(filePathFile)) } if (isInIgnoredFiles) { logger.debug(s"'$filePath' ignored (--exclude)") @@ -46,13 +116,23 @@ object SourceFiles { } } - /** Method to filter file based on the passed parameters + /** Filters a file based on the provided ignore rules. + * + * This method determines whether a given file should be excluded from processing based on several criteria, such as + * matching default ignore patterns, specific file name patterns, or explicit file paths to ignore. + * * @param file + * The file name or path to evaluate. * @param inputPath + * The root input path for the file traversal. * @param ignoredDefaultRegex + * Optional sequence of regular expressions defining default file patterns to ignore. * @param ignoredFilesRegex + * Optional regular expression defining specific file name patterns to ignore. * @param ignoredFilesPath + * Optional sequence of file paths to explicitly exclude. * @return + * `true` if the file is accepted, i.e., does not match any of the ignore criteria, `false` otherwise. */ def filterFile( file: String, @@ -64,7 +144,25 @@ object SourceFiles { && !ignoredFilesRegex.exists(isIgnoredByRegex(file, inputPath, _)) && !ignoredFilesPath.exists(isIgnoredByFileList(file, _)) - private def filterFiles( + /** Filters a list of files based on the provided ignore rules. + * + * This method applies [[filterFile]] to each file in the input list, returning only those files that do not match + * any of the ignore criteria. + * + * @param files + * The list of file names or paths to evaluate. + * @param inputPath + * The root input path for the file traversal. + * @param ignoredDefaultRegex + * Optional sequence of regular expressions defining default file patterns to ignore. + * @param ignoredFilesRegex + * Optional regular expression defining specific file name patterns to ignore. + * @param ignoredFilesPath + * Optional sequence of file paths to explicitly exclude. + * @return + * A filtered list of files that do not match the ignore criteria. + */ + def filterFiles( files: List[String], inputPath: String, ignoredDefaultRegex: Option[Seq[Regex]] = None, @@ -72,8 +170,49 @@ object SourceFiles { ignoredFilesPath: Option[Seq[String]] = None ): List[String] = files.filter(filterFile(_, inputPath, ignoredDefaultRegex, ignoredFilesRegex, ignoredFilesPath)) - /** For given input paths, determine all source files by inspecting filename extensions and filter the result if - * following arguments ignoredDefaultRegex, ignoredFilesRegex and ignoredFilesPath are used + private def hasSourceFileExtension(file: File, sourceFileExtensions: Set[String]): Boolean = + sourceFileExtensions.exists(ext => file.pathAsString.endsWith(ext)) + + /** Determines a sorted list of file paths in a directory that match the specified criteria. + * + * @param inputPath + * The root directory to search for files. + * @param sourceFileExtensions + * A set of file extensions to include in the search. + * @param ignoredDefaultRegex + * An optional sequence of regular expressions for default files to ignore. + * @param ignoredFilesRegex + * An optional regular expression for additional files to ignore. + * @param ignoredFilesPath + * An optional sequence of specific file paths to ignore. + * @param visitOptions + * Implicit parameter defining the options for visiting the file tree. Defaults to `VisitOptions.follow`, which + * follows symbolic links. + * @return + * A sorted `List[String]` of file paths matching the criteria. + * + * This function traverses the file tree starting at the given `inputPath` and collects file paths that: + * - Have extensions specified in `sourceFileExtensions`. + * - Are not ignored based on `ignoredDefaultRegex`, `ignoredFilesRegex`, or `ignoredFilesPath`. + * + * It uses a custom `FailsafeFileVisitor` to handle the filtering logic and `Files.walkFileTree` to perform the + * traversal. + * + * Example usage: + * {{{ + * val files = determine( + * inputPath = "/path/to/dir", + * sourceFileExtensions = Set(".scala", ".java"), + * ignoredDefaultRegex = Some(Seq(".*\\.tmp".r)), + * ignoredFilesRegex = Some(".*_backup\\.scala".r), + * ignoredFilesPath = Some(Seq("/path/to/dir/ignore_me.scala")) + * ) + * println(files) + * }}} + * @throws java.io.FileNotFoundException + * if the `inputPath` does not exist or is not readable. + * @see + * [[FailsafeFileVisitor]] for details on the visitor used to process files. */ def determine( inputPath: String, @@ -82,62 +221,38 @@ object SourceFiles { ignoredFilesRegex: Option[Regex] = None, ignoredFilesPath: Option[Seq[String]] = None )(implicit visitOptions: VisitOptions = VisitOptions.follow): List[String] = { - filterFiles( - determine(Set(inputPath), sourceFileExtensions), - inputPath, + val dir = File(inputPath) + assertExists(dir) + val visitor = new FailsafeFileVisitor( + dir.pathAsString, + sourceFileExtensions, ignoredDefaultRegex, ignoredFilesRegex, ignoredFilesPath ) + Files.walkFileTree(dir.path, visitOptions.toSet.asJava, Int.MaxValue, visitor) + val matchingFiles = visitor.files().map(_.pathAsString) + matchingFiles.toList.sorted } - /** For a given array of input paths, determine all source files by inspecting filename extensions. - */ - def determine(inputPaths: Set[String], sourceFileExtensions: Set[String])(implicit - visitOptions: VisitOptions - ): List[String] = { - def hasSourceFileExtension(file: File): Boolean = - file.extension.exists(sourceFileExtensions.contains) - - val inputFiles = inputPaths.map(File(_)) - assertAllExist(inputFiles) - - val (dirs, files) = inputFiles.partition(_.isDirectory) - - val matchingFiles = files.filter(hasSourceFileExtension).map(_.toString) - val matchingFilesFromDirs = dirs - .flatMap(_.listRecursively) - .filter(hasSourceFileExtension) - .map(_.pathAsString) - - (matchingFiles ++ matchingFilesFromDirs).toList.sorted - } - - /** Attempting to analyse source paths that do not exist is a hard error. Terminate execution early to avoid - * unexpected and hard-to-debug issues in the results. + /** Asserts that a given file exists and is readable. + * + * This method validates the existence and readability of the specified file. If the file does not exist or is not + * readable, it logs an error and throws a [[FileNotFoundException]]. + * + * @param file + * The file to validate. + * @throws FileNotFoundException + * if the file does not exist or is not readable. */ - private def assertAllExist(files: Set[File]): Unit = { - val (existant, nonExistant) = files.partition(_.isReadable) - val nonReadable = existant.filterNot(_.isReadable) - - if (nonExistant.nonEmpty || nonReadable.nonEmpty) { - logErrorWithPaths("Source input paths do not exist", nonExistant.map(_.canonicalPath)) - - logErrorWithPaths("Source input paths exist, but are not readable", nonReadable.map(_.canonicalPath)) - - throw FileNotFoundException("Invalid source paths provided") + private def assertExists(file: File): Unit = { + if (!file.exists) { + logger.error(s"Source input path does not exist: ${file.pathAsString}") + throw FileNotFoundException("Invalid source path provided!") } - } - - private def logErrorWithPaths(message: String, paths: Iterable[String]): Unit = { - val pathsArray = paths.toArray.sorted - - pathsArray.lengthCompare(1) match { - case cmp if cmp < 0 => // pathsArray is empty, so don't log anything - case cmp if cmp == 0 => logger.error(s"$message: ${paths.head}") - case _ => - val errorMessage = (message +: pathsArray.map(path => s"- $path")).mkString("\n") - logger.error(errorMessage) + if (!file.isReadable) { + logger.error(s"Source input path exists, but is not readable: ${file.pathAsString}") + throw FileNotFoundException("Invalid source path provided!") } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/X2Cpg.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/X2Cpg.scala index ca5b223ec8b1..a8b3a5b9f229 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/X2Cpg.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/X2Cpg.scala @@ -2,11 +2,11 @@ package io.joern.x2cpg import better.files.File import io.joern.x2cpg.X2Cpg.{applyDefaultOverlays, withErrorsToConsole} +import io.joern.x2cpg.frontendspecific.FrontendArgsDelimitor import io.joern.x2cpg.layers.{Base, CallGraph, ControlFlow, TypeRelations} import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext} import org.slf4j.LoggerFactory -import overflowdb.Config import scopt.OParser import java.io.PrintWriter @@ -15,12 +15,14 @@ import scala.util.matching.Regex import scala.util.{Failure, Success, Try} object X2CpgConfig { + def defaultInputPath: String = "" def defaultOutputPath: String = "cpg.bin" } trait X2CpgConfig[R <: X2CpgConfig[R]] { - var inputPath: String = "" - var outputPath: String = X2CpgConfig.defaultOutputPath + var inputPath: String = X2CpgConfig.defaultInputPath + var outputPath: String = X2CpgConfig.defaultOutputPath + var serverMode: Boolean = false def withInputPath(inputPath: String): R = { this.inputPath = Paths.get(inputPath).toAbsolutePath.normalize().toString @@ -32,6 +34,11 @@ trait X2CpgConfig[R <: X2CpgConfig[R]] { this.asInstanceOf[R] } + def withServerMode(x: Boolean): R = { + this.serverMode = x + this.asInstanceOf[R] + } + var defaultIgnoredFilesRegex: Seq[Regex] = Seq.empty var ignoredFilesRegex: Regex = "".r var ignoredFiles: Seq[String] = Seq.empty @@ -74,6 +81,7 @@ trait X2CpgConfig[R <: X2CpgConfig[R]] { def withInheritedFields(config: R): R = { this.inputPath = config.inputPath this.outputPath = config.outputPath + this.serverMode = config.serverMode this.defaultIgnoredFilesRegex = config.defaultIgnoredFilesRegex this.ignoredFilesRegex = config.ignoredFilesRegex this.ignoredFiles = config.ignoredFiles @@ -113,9 +121,10 @@ object DependencyDownloadConfig { * @param frontend * the frontend to use for CPG creation */ -abstract class X2CpgMain[T <: X2CpgConfig[T], X <: X2CpgFrontend[?]](val cmdLineParser: OParser[Unit, T], frontend: X)( - implicit defaultConfig: T -) { +abstract class X2CpgMain[T <: X2CpgConfig[T], X <: X2CpgFrontend[T]]( + val cmdLineParser: OParser[Unit, T], + val frontend: X +)(implicit defaultConfig: T) { private val logger = LoggerFactory.getLogger(classOf[X2CpgMain[T, X]]) @@ -149,7 +158,6 @@ abstract class X2CpgMain[T <: X2CpgConfig[T], X <: X2CpgFrontend[?]](val cmdLine run(config, frontend) } catch { case ex: Throwable => - println(ex.getMessage) ex.printStackTrace() System.exit(1) } @@ -163,7 +171,7 @@ abstract class X2CpgMain[T <: X2CpgConfig[T], X <: X2CpgFrontend[?]](val cmdLine /** Trait that represents a CPG generator, where T is the frontend configuration class. */ -trait X2CpgFrontend[T <: X2CpgConfig[?]] { +trait X2CpgFrontend[T <: X2CpgConfig[T]] { /** Create a CPG according to given configuration. Returns CPG wrapped in a `Try`, making it possible to detect and * inspect exceptions in CPG generation. To be provided by the frontend. @@ -173,15 +181,24 @@ trait X2CpgFrontend[T <: X2CpgConfig[?]] { /** Create CPG according to given configuration, printing errors to the console if they occur. The CPG is closed and * not returned. */ + @throws[Throwable]("if createCpg throws any Throwable") def run(config: T): Unit = { withErrorsToConsole(config) { _ => createCpg(config) match { case Success(cpg) => - cpg.close() + cpg.close() // persists to disk Success(cpg) case Failure(exception) => Failure(exception) } + } match { + case Failure(exception) => + // We explicitly rethrow the exception so that every frontend will + // terminate with exit code 1 if there was an exception during createCpg. + // Frontend maintainer may want to catch that RuntimeException on their end + // to add custom error handling. + throw exception + case Success(_) => // this is fine } } @@ -209,12 +226,12 @@ trait X2CpgFrontend[T <: X2CpgConfig[?]] { * exists, it is the file name of the resulting CPG. Otherwise, the CPG is held in memory. */ def createCpg(inputName: String, outputName: Option[String])(implicit defaultConfig: T): Try[Cpg] = { - val defaultWithInputPath = defaultConfig.withInputPath(inputName).asInstanceOf[T] + val defaultWithInputPath = defaultConfig.withInputPath(inputName) val config = if (!outputName.contains(X2CpgConfig.defaultOutputPath)) { if (outputName.isEmpty) { - defaultWithInputPath.withOutputPath("").asInstanceOf[T] + defaultWithInputPath.withOutputPath("") } else { - defaultWithInputPath.withOutputPath(outputName.get).asInstanceOf[T] + defaultWithInputPath.withOutputPath(outputName.get) } } else { defaultWithInputPath @@ -260,8 +277,14 @@ object X2Cpg { .action { (x, c) => c.withOutputPath(x) }, + + // previously this was supposed to be called with `,` as a separator, + // e.g. `--exclude foo,bar` - which (among others) has the disadvantage + // that under windows a `,` is treated as an argument separator + // better: provide this argument multiple times, i.e. `--exclude foo --exclude bar` opt[Seq[String]]("exclude") - .valueName(",,...") + .valueName("") + .unbounded() .action { (x, c) => c.ignoredFiles = c.ignoredFiles ++ x.map(c.createPathForIgnore) c @@ -281,6 +304,10 @@ object X2Cpg { .text( "add the raw source code to the content field of FILE nodes to allow for method source retrieval via offset fields (disabled by default)" ), + opt[Unit]("server") + .action((_, c) => c.withServerMode(true)) + .hidden() + .text("runs this frontend in server mode (disabled by default)"), opt[Unit]("disable-file-content") .action((_, c) => c.withDisableFileContent(true)) .hidden() @@ -293,19 +320,16 @@ object X2Cpg { /** Create an empty CPG, backed by the file at `optionalOutputPath` or in-memory if `optionalOutputPath` is empty. */ def newEmptyCpg(optionalOutputPath: Option[String] = None): Cpg = { - val odbConfig = optionalOutputPath - .map { outputPath => - val outFile = File(outputPath) + optionalOutputPath match { + case Some(outputPath) => + lazy val outFile = File(outputPath) if (outputPath != "" && outFile.exists) { logger.info("Output file exists, removing: " + outputPath) outFile.delete() } - Config.withDefaults.withStorageLocation(outputPath) - } - .getOrElse { - Config.withDefaults() - } - Cpg.withConfig(odbConfig) + Cpg.withStorage(outFile.path) + case None => Cpg.empty + } } /** Apply function `applyPasses` to a newly created CPG. The CPG is wrapped in a `Try` and returned. On failure, the @@ -368,7 +392,7 @@ object X2Cpg { } /** Strips surrounding quotation characters from a string. - * @param s + * @param str * the target string. * @return * the stripped string. diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/astgen/AstGenRunner.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/astgen/AstGenRunner.scala index c9a9a514efab..799ed3da50ca 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/astgen/AstGenRunner.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/astgen/AstGenRunner.scala @@ -46,24 +46,14 @@ object AstGenRunner { packagePath: URL ) - def executableDir(implicit metaData: AstGenProgramMetaData): String = { - val dir = metaData.packagePath.toString - val indexOfLib = dir.lastIndexOf("lib") - val fixedDir = if (indexOfLib != -1) { - new java.io.File(dir.substring("file:".length, indexOfLib)).toString - } else { - val indexOfTarget = dir.lastIndexOf("target") - if (indexOfTarget != -1) { - new java.io.File(dir.substring("file:".length, indexOfTarget)).toString - } else { - "." - } - } - Paths.get(fixedDir, "/bin/astgen").toAbsolutePath.toString - } + def executableDir(implicit metaData: AstGenProgramMetaData): String = + ExternalCommand + .executableDir(Paths.get(metaData.packagePath.toURI)) + .resolve("astgen") + .toString def hasCompatibleAstGenVersion(compatibleVersion: String)(implicit metaData: AstGenProgramMetaData): Boolean = { - ExternalCommand.run(s"$metaData.name -version", ".").toOption.map(_.mkString.strip()) match { + ExternalCommand.run(Seq(metaData.name, "-version"), ".").successOption.map(_.mkString.strip()) match { case Some(installedVersion) if installedVersion != "unknown" && Try(VersionHelper.compare(installedVersion, compatibleVersion)).toOption.getOrElse(-1) >= 0 => @@ -74,7 +64,8 @@ object AstGenRunner { s"Found local ${metaData.name} v$installedVersion in systems PATH but ${metaData.name} requires at least v$compatibleVersion" ) false - case _ => false + case _ => + false } } @@ -124,7 +115,9 @@ trait AstGenRunnerBase(config: X2CpgConfig[?] & AstGenConfig[?]) { } } - private def executableName(x86Suffix: String, armSuffix: String)(implicit metaData: AstGenProgramMetaData): String = { + protected def executableName(x86Suffix: String, armSuffix: String)(implicit + metaData: AstGenProgramMetaData + ): String = { if (metaData.multiArchitectureBuilds) { s"${metaData.name}-$x86Suffix" } else { diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/javasrc2cpg/JavaTypeRecoveryPassGenerator.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/javasrc2cpg/JavaTypeRecoveryPassGenerator.scala index 607bc70fa3dd..936595fc045b 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/javasrc2cpg/JavaTypeRecoveryPassGenerator.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/javasrc2cpg/JavaTypeRecoveryPassGenerator.scala @@ -5,7 +5,7 @@ import io.joern.x2cpg.passes.frontend.* import io.shiftleft.codepropertygraph.generated.{Cpg, PropertyNames} import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder class JavaTypeRecoveryPassGenerator(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) extends XTypeRecoveryPassGenerator[Method](cpg, config) { diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/ConstClosurePass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/ConstClosurePass.scala index fedb8e0f77d7..f9b4b0f339c9 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/ConstClosurePass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/ConstClosurePass.scala @@ -4,7 +4,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.PropertyNames import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Method, MethodRef} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* /** A pass that identifies assignments of closures to constants and updates `METHOD` nodes accordingly. */ diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/JavaScriptTypeRecovery.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/JavaScriptTypeRecovery.scala index 3264b4e73f05..2d9c8072f1f2 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/JavaScriptTypeRecovery.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/JavaScriptTypeRecovery.scala @@ -5,10 +5,10 @@ import io.joern.x2cpg.Defines.ConstructorMethodName import io.joern.x2cpg.passes.frontend.* import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames} +import io.shiftleft.codepropertygraph.generated.{Operators, Properties, PropertyNames} import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder class JavaScriptTypeRecoveryPassGenerator(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) extends XTypeRecoveryPassGenerator[File](cpg, config) { @@ -48,9 +48,9 @@ private class RecoverForJavaScriptFile(cpg: Cpg, cu: File, builder: DiffGraphBui override protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match { case x @ (_: Identifier | _: Local | _: MethodParameterIn) - if x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) != Defines.Any => - val typeFullName = x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) - val typeHints = symbolTable.get(LocalVar(x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any))) - typeFullName + if x.propertyOption(Properties.TypeFullName).getOrElse(Defines.Any) != Defines.Any => + val typeFullName = x.propertyOption(Properties.TypeFullName).getOrElse(Defines.Any) + val typeHints = symbolTable.get(LocalVar(typeFullName)) - typeFullName lazy val cpgTypeFullName = cpg.typeDecl.nameExact(typeFullName).fullName.toSet val resolvedTypeHints = if (typeHints.nonEmpty) symbolTable.put(x, typeHints) @@ -59,9 +59,8 @@ private class RecoverForJavaScriptFile(cpg: Cpg, cu: File, builder: DiffGraphBui if (!resolvedTypeHints.contains(typeFullName) && resolvedTypeHints.sizeIs == 1) builder.setNodeProperty(x, PropertyNames.TYPE_FULL_NAME, resolvedTypeHints.head) - case x @ (_: Identifier | _: Local | _: MethodParameterIn) - if x.property(PropertyNames.POSSIBLE_TYPES, Seq.empty[String]).nonEmpty => - val possibleTypes = x.property(PropertyNames.POSSIBLE_TYPES, Seq.empty[String]) + case x @ (_: Identifier | _: Local | _: MethodParameterIn) if x.property(Properties.PossibleTypes).nonEmpty => + val possibleTypes = x.property(Properties.PossibleTypes) if (possibleTypes.sizeIs == 1 && !possibleTypes.contains("ANY")) { val typeFullName = possibleTypes.head val typeHints = symbolTable.get(LocalVar(typeFullName)) - typeFullName diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/ObjectPropertyCallLinker.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/ObjectPropertyCallLinker.scala index cb09e7df9d4d..4b77b193ed03 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/ObjectPropertyCallLinker.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/jssrc2cpg/ObjectPropertyCallLinker.scala @@ -3,7 +3,6 @@ package io.joern.x2cpg.frontendspecific.jssrc2cpg import io.shiftleft.codepropertygraph.generated.nodes.{Call, MethodRef} import io.shiftleft.codepropertygraph.generated.{Cpg, PropertyNames} import io.shiftleft.passes.CpgPass -import overflowdb.BatchedUpdate import io.shiftleft.semanticcpg.language.* /** Perform a simple analysis to find a common pattern in JavaScript where objects are dynamically assigned function @@ -13,7 +12,7 @@ import io.shiftleft.semanticcpg.language.* */ class ObjectPropertyCallLinker(cpg: Cpg) extends CpgPass(cpg) { - override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = { + override def run(builder: DiffGraphBuilder): Unit = { def propertyCallRegexPattern(withMatchingGroup: Boolean): String = "^(?:\\{.*\\}|.*):\\(" + (if withMatchingGroup then "(.*)" else ".*") + "\\):.*$" diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/package.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/package.scala index 7434394eb0b4..e6eebfa8eba4 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/package.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/package.scala @@ -10,4 +10,7 @@ package io.joern.x2cpg * Otherwise we'll end in jar hell with various incompatible versions of many different dependencies, and complex * issues with things like OSGI and JPMS. */ -package object frontendspecific +package object frontendspecific { + // Special string used to separate joern-parse opts from frontend-specific opts + val FrontendArgsDelimitor = "--frontend-args" +} diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/php2cpg/PhpTypeRecovery.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/php2cpg/PhpTypeRecovery.scala index 9c7857a95ccd..6399506585c6 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/php2cpg/PhpTypeRecovery.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/php2cpg/PhpTypeRecovery.scala @@ -8,7 +8,7 @@ import io.shiftleft.codepropertygraph.generated.{Cpg, DispatchTypes, Operators, import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess} -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import scala.collection.mutable @@ -150,7 +150,7 @@ private class RecoverForPhpFile(cpg: Cpg, cu: NamespaceBlock, builder: DiffGraph symbolTable.append(head, callees) case _ => Set.empty } - val returnTypes = extractTypes(ret.argumentOut.l) + val returnTypes = extractTypes(ret.argumentOut.cast[CfgNode].l) existingTypes.addAll(returnTypes) /* Check whether method return is already known, and if so, remove dummy value */ @@ -221,7 +221,7 @@ private class RecoverForPhpFile(cpg: Cpg, cu: NamespaceBlock, builder: DiffGraph .getOrElse(XTypeRecovery.DummyIndexAccess) else x.name - val collectionVar = Option(c.argumentOut.l match { + val collectionVar = Option(c.argumentOut.cast[CfgNode].l match { case List(i: Identifier, idx: Literal) => CollectionVar(i.name, idx.code) case List(i: Identifier, idx: Identifier) => CollectionVar(i.name, idx.code) case List(c: Call, idx: Call) => CollectionVar(callName(c), callName(idx)) diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/php2cpg/PhpTypeStubsParser.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/php2cpg/PhpTypeStubsParser.scala index e8f2cb4cec7c..c7d43c33f4e9 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/php2cpg/PhpTypeStubsParser.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/php2cpg/PhpTypeStubsParser.scala @@ -9,7 +9,6 @@ import io.shiftleft.passes.ForkJoinParallelCpgPass import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes import org.slf4j.{Logger, LoggerFactory} -import overflowdb.BatchedUpdate import scopt.OParser import java.io.File as JFile @@ -47,7 +46,7 @@ class PhpTypeStubsParserPass(cpg: Cpg, config: XTypeStubsParserConfig = XTypeStu arr } - override def runOnPart(builder: overflowdb.BatchedUpdate.DiffGraphBuilder, part: KnownFunction): Unit = { + override def runOnPart(builder: DiffGraphBuilder, part: KnownFunction): Unit = { /* calculate the result of this part - this is done as a concurrent task */ val builtinMethod = cpg.method.fullNameExact(part.name).l builtinMethod.foreach(mNode => { @@ -73,7 +72,7 @@ class PhpTypeStubsParserPass(cpg: Cpg, config: XTypeStubsParserConfig = XTypeStu def scanParamTypes(pTypesRawArr: List[String]): Seq[Seq[String]] = pTypesRawArr.map(paramTypeRaw => paramTypeRaw.split(",").map(_.strip).toSeq).toSeq - protected def setTypes(builder: overflowdb.BatchedUpdate.DiffGraphBuilder, n: StoredNode, types: Seq[String]): Unit = + protected def setTypes(builder: DiffGraphBuilder, n: StoredNode, types: Seq[String]): Unit = if (types.size == 1) builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.head) else builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, types) } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/Constants.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/Constants.scala index 328bb2612c18..68ab95f21990 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/Constants.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/Constants.scala @@ -10,4 +10,7 @@ object Constants { val builtinIntType = s"${builtinPrefix}int" val builtinFloatType = s"${builtinPrefix}float" val builtinComplexType = s"${builtinPrefix}complex" + + val moduleName = "" + val initName = "__init__" } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/DynamicTypeHintFullNamePass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/DynamicTypeHintFullNamePass.scala index 5675afe23e63..140de6685fc6 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/DynamicTypeHintFullNamePass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/DynamicTypeHintFullNamePass.scala @@ -5,7 +5,6 @@ import io.shiftleft.codepropertygraph.generated.{Cpg, PropertyNames} import io.shiftleft.codepropertygraph.generated.nodes.{CfgNode, MethodParameterIn, MethodReturn, StoredNode} import io.shiftleft.passes.ForkJoinParallelCpgPass import io.shiftleft.semanticcpg.language.* -import overflowdb.BatchedUpdate import java.io.File import java.util.regex.{Matcher, Pattern} @@ -71,10 +70,10 @@ class DynamicTypeHintFullNamePass(cpg: Cpg) extends ForkJoinParallelCpgPass[CfgN } private def pythonicTypeNameToImport(fullName: String): String = - fullName.replaceFirst("\\.py:", "").replaceAll(Pattern.quote(File.separator), ".") + fullName.replaceFirst(s"\\.py:${Constants.moduleName}", "").replaceAll(Pattern.quote(File.separator), ".") private def setTypeHints( - diffGraph: BatchedUpdate.DiffGraphBuilder, + diffGraph: DiffGraphBuilder, node: StoredNode, typeHint: String, alias: String, @@ -85,7 +84,7 @@ class DynamicTypeHintFullNamePass(cpg: Cpg) extends ForkJoinParallelCpgPass[CfgN val typeFilePath = typeHintFullName.replaceAll("\\.", Matcher.quoteReplacement(File.separator)) val pythonicTypeFullName = importFullPath.split("\\.").lastOption match { case Some(typeName) => - typeFilePath.stripSuffix(s"${File.separator}$typeName").concat(s".py:.$typeName") + typeFilePath.stripSuffix(s"${File.separator}$typeName").concat(s".py:${Constants.moduleName}.$typeName") case None => typeHintFullName } cpg.typeDecl.fullName(s".*${Pattern.quote(pythonicTypeFullName)}").l match { diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonImportResolverPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonImportResolverPass.scala index 61a0f3b08b36..66fb8347706c 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonImportResolverPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonImportResolverPass.scala @@ -23,7 +23,7 @@ class PythonImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { private val moduleCache: mutable.HashMap[String, ImportableEntity] = mutable.HashMap.empty override def init(): Unit = { - cpg.typeDecl.isExternal(false).nameExact("").foreach { moduleType => + cpg.typeDecl.isExternal(false).nameExact(Constants.moduleName).foreach { moduleType => val modulePath = fileToPythonImportNotation(moduleType.filename) cpg.method.fullNameExact(moduleType.fullName).headOption.foreach { moduleMethod => moduleCache.put(modulePath, Module(moduleType, moduleMethod)) @@ -48,7 +48,7 @@ class PythonImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { .stripPrefix(codeRootDir) .replaceAll(Matcher.quoteReplacement(JFile.separator), ".") .stripSuffix(".py") - .stripSuffix(".__init__") + .stripSuffix(s".${Constants.initName}") override protected def optionalResolveImport( fileName: String, @@ -103,16 +103,20 @@ class PythonImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { def toUnresolvedImport(pseudoPath: String): Set[EvaluatedImport] = { if (isMaybeConstructor) { - Set(UnknownMethod(Seq(pseudoPath, "__init__").mkString(pathSep.toString), alias), UnknownTypeDecl(pseudoPath)) + Set( + UnknownMethod(Seq(pseudoPath, Constants.initName).mkString(pathSep.toString), alias), + UnknownTypeDecl(pseudoPath) + ) } else { Set(UnknownImport(pseudoPath)) } } expEntity.split(pathSep).reverse.toList match - case name :: Nil => toUnresolvedImport(s"$name.py:") - case name :: xs => toUnresolvedImport(s"${xs.reverse.mkString(JFile.separator)}.py:$pathSep$name") - case Nil => Set.empty + case name :: Nil => toUnresolvedImport(s"$name.py:${Constants.moduleName}") + case name :: xs => + toUnresolvedImport(s"${xs.reverse.mkString(JFile.separator)}.py:${Constants.moduleName}$pathSep$name") + case Nil => Set.empty } private sealed trait ImportableEntity { @@ -140,6 +144,6 @@ class PythonImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { private case class ImportableType(typ: TypeDecl) extends ImportableEntity { override def toResolvedImport(alias: String): List[EvaluatedImport] = - List(ResolvedTypeDecl(typ.fullName), ResolvedMethod(s"${typ.fullName}.__init__", typ.name)) + List(ResolvedTypeDecl(typ.fullName), ResolvedMethod(s"${typ.fullName}.${Constants.initName}", typ.name)) } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonInheritanceNamePass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonInheritanceNamePass.scala index aab1e5d803d0..2987881dd226 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonInheritanceNamePass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonInheritanceNamePass.scala @@ -8,7 +8,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg */ class PythonInheritanceNamePass(cpg: Cpg) extends XInheritanceFullNamePass(cpg) { - override val moduleName: String = "" + override val moduleName: String = Constants.moduleName override val fileExt: String = ".py" } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonTypeHintCallLinker.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonTypeHintCallLinker.scala index cfaa52cc2d9d..fbf5dd2fa33d 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonTypeHintCallLinker.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonTypeHintCallLinker.scala @@ -8,11 +8,11 @@ import io.shiftleft.semanticcpg.language.* class PythonTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg) { - override def calls: Iterator[Call] = super.calls.nameNot("^(import).*") + override def calls: Iterator[Call] = super.calls.whereNot(_.isImport) override def calleeNames(c: Call): Seq[String] = super.calleeNames(c).map { // Python call from a type - case typ if typ.split("\\.").lastOption.exists(_.charAt(0).isUpper) => s"$typ.__init__" + case typ if typ.split("\\.").lastOption.exists(_.charAt(0).isUpper) => s"$typ.${Constants.initName}" // Python call from a function pointer case typ => typ } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonTypeRecovery.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonTypeRecovery.scala index 6449cfc8e013..d652e1641708 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonTypeRecovery.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/pysrc2cpg/PythonTypeRecovery.scala @@ -1,19 +1,13 @@ package io.joern.x2cpg.frontendspecific.pysrc2cpg -import io.joern.x2cpg.passes.frontend.{RecoverForXCompilationUnit, XTypeRecovery, XTypeRecoveryState} -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.File -import io.shiftleft.semanticcpg.language.* -import io.joern.x2cpg.frontendspecific.pysrc2cpg.Constants +import io.joern.x2cpg.Defines import io.joern.x2cpg.passes.frontend.* -import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames} +import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder, Operators, PropertyNames} import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.importresolver.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess -import overflowdb.BatchedUpdate.DiffGraphBuilder private class PythonTypeRecovery(cpg: Cpg, state: XTypeRecoveryState, iteration: Int) extends XTypeRecovery[File](cpg, state, iteration) { @@ -56,7 +50,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder .member .nameExact(memberName) .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) - .filterNot(_ == "ANY") + .filterNot(_ == Defines.Any) .toSet symbolTable.put(LocalVar(entityName), memberTypes) case UnknownMethod(fullName, alias, receiver, _) => @@ -94,7 +88,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder /** If the parent method is module then it can be used as a field. */ override def isFieldUncached(i: Identifier): Boolean = - i.method.name.matches("(|__init__)") || super.isFieldUncached(i) + i.method.name.matches(s"(${Constants.moduleName}|${Constants.initName})") || super.isFieldUncached(i) override def visitIdentifierAssignedToOperator(i: Identifier, c: Call, operation: String): Set[String] = { operation match { @@ -111,7 +105,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder } override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = { - val constructorPaths = symbolTable.get(c).map(_.stripSuffix(s"${pathSep}__init__")) + val constructorPaths = symbolTable.get(c).map(_.stripSuffix(s"$pathSep${Constants.initName}")) associateTypes(i, constructorPaths) } @@ -144,7 +138,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder } override def getFieldParents(fa: FieldAccess): Set[String] = { - if (fa.method.name == "") { + if (fa.method.name == Constants.moduleName) { Set(fa.method.fullName) } else if (fa.method.typeDecl.nonEmpty) { val parentTypes = fa.method.typeDecl.fullName.toSet @@ -204,7 +198,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder .foreach { cls => val clsPath = classMethod.typeDecl.fullName.toSet symbolTable.put(LocalVar(cls.name), clsPath) - if (cls.typeFullName == "ANY") + if (cls.typeFullName == Defines.Any) builder.setNodeProperty(cls, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, clsPath.toSeq) } } @@ -224,7 +218,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder funcName: String, baseName: Option[String] ): Unit = { - if (funcName != "") + if (funcName != Constants.moduleName) super.handlePotentialFunctionPointer(funcPtr, baseTypes, funcName, baseName) } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/Constants.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/Constants.scala new file mode 100644 index 000000000000..821a95293b68 --- /dev/null +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/Constants.scala @@ -0,0 +1,93 @@ +package io.joern.x2cpg.frontendspecific.rubysrc2cpg + +object Constants { + + val builtinPrefix = "__core" + val kernelPrefix = s"$builtinPrefix.Kernel" + val Initialize = "initialize" + val Main = "
" + + /* Source: https://ruby-doc.org/3.2.2/Kernel.html + * + * We comment-out methods that require an explicit "receiver" (target of member access.) + */ + val kernelFunctions: Set[String] = Set( + "Array", + "Complex", + "Float", + "Hash", + "Integer", + "Rational", + "String", + "__callee__", + "__dir__", + "__method__", + "abort", + "at_exit", + "autoload", + "autoload?", + "binding", + "block_given?", + "callcc", + "caller", + "caller_locations", + "catch", + "chomp", + "chomp!", + "chop", + "chop!", + // "class", + // "clone", + "eval", + "exec", + "exit", + "exit!", + "fail", + "fork", + "format", + // "frozen?", + "gets", + "global_variables", + "gsub", + "gsub!", + "iterator?", + "lambda", + "load", + "local_variables", + "loop", + "open", + "p", + "print", + "printf", + "proc", + "putc", + "puts", + "raise", + "rand", + "readline", + "readlines", + "require", + "require_all", + "require_relative", + "select", + "set_trace_func", + "sleep", + "spawn", + "sprintf", + "srand", + "sub", + "sub!", + "syscall", + "system", + "tap", + "test", + // "then", + "throw", + "trace_var", + // "trap", + "untrace_var", + "warn" + // "yield_self", + ) + +} diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/ImplicitRequirePass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/ImplicitRequirePass.scala new file mode 100644 index 000000000000..2bf252eccd21 --- /dev/null +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/ImplicitRequirePass.scala @@ -0,0 +1,196 @@ +package io.joern.x2cpg.frontendspecific.rubysrc2cpg + +import io.joern.x2cpg.Defines +import io.joern.x2cpg.frontendspecific.rubysrc2cpg.Constants.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{Cpg, DispatchTypes, EdgeTypes, Operators} +import io.shiftleft.passes.ForkJoinParallelCpgPass +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess} +import org.apache.commons.text.CaseUtils + +import java.util.regex.Pattern +import scala.annotation.tailrec +import scala.collection.mutable + +/** A tuple holding the (name, importPath) for types in the analysis. + */ +case class TypeImportInfo(name: String, importPath: String) + +/** In some Ruby frameworks, it is common to have an autoloader library that implicitly loads requirements onto the + * stack. This pass makes these imports explicit. The most popular one is Zeitwerk which we check in `Gemsfile.lock` to enable this pass. + * + * @param externalTypes + * a list of additional types to consider that may be importable but are not in the CPG. + */ +class ImplicitRequirePass(cpg: Cpg, externalTypes: Seq[TypeImportInfo] = Nil) + extends ForkJoinParallelCpgPass[Method](cpg) { + + /** A tuple holding information about the type import info, additionally with a boolean indicating if it is external + * or not. + */ + private case class TypeImportInfoWithProvidence(info: TypeImportInfo, isExternal: Boolean) + private val typeNameToImportInfo = mutable.Map.empty[String, Seq[TypeImportInfoWithProvidence]] + + private val Require: String = "require" + private val Self: String = "self" + private val Initialize: String = "initialize" + private val Clazz: String = "" + + override def init(): Unit = { + val importableTypeInfo = cpg.typeDecl + .isExternal(false) + .filter { typeDecl => + // zeitwerk will match types that share the name of the path. + // This match is insensitive to camel case, i.e, foo_bar will match type FooBar. + val fileName = typeDecl.filename.split(Array('/', '\\')).last + val typeName = typeDecl.name + ImplicitRequirePass.isAutoloadable(typeName, fileName) + } + .map { typeDecl => + val typeImportInfo = TypeImportInfo(typeDecl.name, ImplicitRequirePass.normalizePath(typeDecl.filename)) + TypeImportInfoWithProvidence(typeImportInfo, typeDecl.isExternal) + } + .l + // Group types by symbol and add to map for quicker retrieval later + typeNameToImportInfo.addAll(importableTypeInfo.groupBy { case TypeImportInfoWithProvidence(typeImportInfo, _) => + typeImportInfo.name + }) + typeNameToImportInfo.addAll(externalTypes.map(TypeImportInfoWithProvidence(_, true)).groupBy { + case TypeImportInfoWithProvidence(typeImportInfo, _) => typeImportInfo.name + }) + } + + private def getFieldBaseFromString(fieldAccessString: String): String = { + val normalizedFieldAccessString = fieldAccessString.replace("::", ".") + normalizedFieldAccessString.split('.').headOption.getOrElse(normalizedFieldAccessString) + } + + override def generateParts(): Array[Method] = + cpg.method.isModule.whereNot(_.astChildren.isCall.nameExact(Require)).toArray + + /** Collects methods within a module. + */ + private def findMethodsViaAstChildren(module: Method): Iterator[Method] = { + // TODO For now we have to go via the full name regex because the AST is not yet linked + // at the execution time of this pass. + // Iterator(module) ++ module.astChildren.flatMap { + // case x: TypeDecl => x.method.flatMap(findMethodsViaAstChildren) + // case x: Method => Iterator(x) ++ x.astChildren.collectAll[Method].flatMap(findMethodsViaAstChildren) + // case _ => Iterator.empty + // } + cpg.method.fullName(Pattern.quote(module.fullName) + ".*") + } + + override def runOnPart(builder: DiffGraphBuilder, moduleMethod: Method): Unit = { + val possiblyImportedSymbols = mutable.ArrayBuffer.empty[String] + val currPath = ImplicitRequirePass.normalizePath(moduleMethod.filename) + + val typeDecl = cpg.typeDecl.fullName(Pattern.quote(moduleMethod.fullName) + ".*").l + typeDecl.inheritsFromTypeFullName + .filterNot(_.endsWith(Clazz)) + .map(getFieldBaseFromString) + .foreach(possiblyImportedSymbols.append) + + val methodsOfModule = findMethodsViaAstChildren(moduleMethod).toList + val callsOfModule = methodsOfModule.ast.isCall.toList + + val symbolsGatheredFromCalls = callsOfModule + .flatMap { + case x if x.name == Initialize => + x.receiver.headOption.flatMap { + case x: TypeRef => Option(getFieldBaseFromString(x.code)) + case x: Identifier => Option(x.name) + case x: Call if x.name == Operators.fieldAccess => + Option(fieldAccessBase(x.asInstanceOf[FieldAccess])) + case _ => None + }.iterator + case x if x.methodFullName == Operators.fieldAccess => + fieldAccessBase(x.asInstanceOf[FieldAccess]) :: Nil + case _ => + Iterator.empty + } + .filterNot(_.isBlank) + + possiblyImportedSymbols.appendAll(symbolsGatheredFromCalls) + + var currOrder = moduleMethod.block.astChildren.size + possiblyImportedSymbols.distinct + .flatMap { identifierName => + typeNameToImportInfo + .getOrElse(identifierName, Seq.empty) + .sortBy { case TypeImportInfoWithProvidence(_, isExternal) => + isExternal // sorting booleans puts false (internal) first + } + .collectFirst { + // ignore an import to a file that defines the type + case TypeImportInfoWithProvidence(TypeImportInfo(_, importPath), _) if importPath != currPath => importPath + } + } + .distinct + .foreach { importPath => + val requireCall = createRequireCall(builder, importPath) + requireCall.order(currOrder) + builder.addEdge(moduleMethod.block, requireCall, EdgeTypes.AST) + currOrder += 1 + } + } + + private def createRequireCall(builder: DiffGraphBuilder, path: String): NewCall = { + val requireCallNode = NewCall() + .name(Require) + .code(s"$Require '$path'") + .methodFullName(s"$kernelPrefix.$Require") + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .typeFullName(Defines.Any) + builder.addNode(requireCallNode) + // Create literal argument + val pathLiteralNode = + NewLiteral().code(s"'$path'").typeFullName(s"$builtinPrefix.String").argumentIndex(1).order(2) + builder.addEdge(requireCallNode, pathLiteralNode, EdgeTypes.AST) + builder.addEdge(requireCallNode, pathLiteralNode, EdgeTypes.ARGUMENT) + requireCallNode + } + + private def fieldAccessBase(fa: FieldAccess): String = fieldAccessParts(fa).headOption.getOrElse(fa.argument(1).code) + + @tailrec + private def fieldAccessParts(fa: FieldAccess): Seq[String] = { + fa.argument(1) match { + case subFa: Call if subFa.name == Operators.fieldAccess => fieldAccessParts(subFa.asInstanceOf[FieldAccess]) + case self: Identifier if self.name == Self => fa.fieldIdentifier.map(_.canonicalName).toSeq + case assignCall: Call if assignCall.name == Operators.assignment => + val assign = assignCall.asInstanceOf[Assignment] + // Handle the tmp var assign of qualified names + (assign.target, assign.source) match { + case (lhs: Identifier, rhs: Call) if lhs.name.startsWith(" + fieldAccessParts(rhs.asInstanceOf[FieldAccess]) + case _ => Seq.empty + } + case _ => Seq.empty + } + } + +} + +object ImplicitRequirePass { + + /** Determines if the given type name and its corresponding parent file name allow for the type to be autoloaded by + * zeitwerk. + * @return + * true if the type is autoloadable from the given filename. + */ + def isAutoloadable(typeName: String, fileName: String): Boolean = { + // We use lowercase as something like `openssl` and `OpenSSL` don't give obvious clues where capitalisation occurs + val strippedFileName = normalizePath(fileName).toLowerCase + val lowerCaseTypeName = typeName.toLowerCase + lowerCaseTypeName == strippedFileName.toLowerCase || lowerCaseTypeName == CaseUtils + .toCamelCase(strippedFileName, true, '_', '-') + .toLowerCase + } + + private def normalizePath(path: String): String = path.replace("\\", "/").stripSuffix(".rb") + +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ImportsPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/ImportsPass.scala similarity index 71% rename from joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ImportsPass.scala rename to joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/ImportsPass.scala index 8467d44ee88f..ea8a427e426d 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ImportsPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/ImportsPass.scala @@ -1,4 +1,4 @@ -package io.joern.rubysrc2cpg.passes +package io.joern.x2cpg.frontendspecific.rubysrc2cpg import io.joern.x2cpg.Imports.createImportNodeAndLink import io.joern.x2cpg.X2Cpg.stripQuotes @@ -9,9 +9,9 @@ import io.shiftleft.semanticcpg.language.* class ImportsPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Call](cpg) { - private val importCallName: Seq[String] = Seq("require", "load", "require_relative", "require_all") - - override def generateParts(): Array[Call] = cpg.call.nameExact(importCallName*).toArray + override def generateParts(): Array[Call] = { + cpg.call.nameExact(ImportsPass.ImportCallNames.toSeq*).isStatic.toArray + } override def runOnPart(diffGraph: DiffGraphBuilder, call: Call): Unit = { val importedEntity = stripQuotes(call.argument.isLiteral.code.l match { @@ -22,3 +22,7 @@ class ImportsPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Call](cpg) { if (call.name == "require_all") importNode.isWildcard(true) } } + +object ImportsPass { + val ImportCallNames: Set[String] = Set("require", "load", "require_relative", "require_all") +} diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyImportResolverPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/RubyImportResolverPass.scala similarity index 90% rename from joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyImportResolverPass.scala rename to joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/RubyImportResolverPass.scala index f52733204573..298d1e94b976 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyImportResolverPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/RubyImportResolverPass.scala @@ -1,14 +1,12 @@ -package io.joern.rubysrc2cpg.passes +package io.joern.x2cpg.frontendspecific.rubysrc2cpg import better.files.File -import io.joern.rubysrc2cpg.deprecated.utils.PackageTable -import io.joern.x2cpg.Defines as XDefines -import io.shiftleft.semanticcpg.language.importresolver.* +import io.joern.x2cpg.frontendspecific.rubysrc2cpg.Constants.* import io.joern.x2cpg.passes.frontend.XImportResolverPass import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Call import io.shiftleft.semanticcpg.language.* -import io.joern.rubysrc2cpg.passes.Defines as RDefines +import io.shiftleft.semanticcpg.language.importresolver.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import java.io.File as JFile @@ -47,7 +45,7 @@ class RubyImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { .flatMap(fullName => Seq( ResolvedTypeDecl(fullName), - ResolvedMethod(s"$fullName.${Defines.Initialize}", "new", fullName.split("[.]").lastOption) + ResolvedMethod(s"$fullName.${Initialize}", "new", fullName.split("[.]").lastOption) ) ) .toSet @@ -61,7 +59,7 @@ class RubyImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { // Expose methods which are directly present in a file, without any module, TypeDecl val resolvedMethods = cpg.method .where(_.file.name(filePattern)) - .where(_.nameExact(RDefines.Program)) + .where(_.nameExact(Main)) .astChildren .astChildren .isMethod diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyTypeHintCallLinker.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/RubyTypeHintCallLinker.scala similarity index 84% rename from joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyTypeHintCallLinker.scala rename to joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/RubyTypeHintCallLinker.scala index 333245e40dda..5244b021fb4a 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyTypeHintCallLinker.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/RubyTypeHintCallLinker.scala @@ -1,6 +1,7 @@ -package io.joern.rubysrc2cpg.passes +package io.joern.x2cpg.frontendspecific.rubysrc2cpg import io.joern.x2cpg.passes.frontend.XTypeHintCallLinker +import io.joern.x2cpg.frontendspecific.rubysrc2cpg.Constants.* import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Call, NewMethod} import io.shiftleft.semanticcpg.language.* @@ -26,8 +27,8 @@ class RubyTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg) { } val name = if (methodName.contains(pathSep) && methodName.length > methodName.lastIndexOf(pathSep) + 1) - val strippedMethod = methodName.stripPrefix(s"${GlobalTypes.kernelPrefix}:") - if GlobalTypes.kernelFunctions.contains(strippedMethod) then strippedMethod + val strippedMethod = methodName.stripPrefix(s"$kernelPrefix.") + if kernelFunctions.contains(strippedMethod) then strippedMethod else methodName.substring(methodName.lastIndexOf(pathSep) + pathSep.length) else methodName createMethodStub(name, methodName, call.argumentOut.size, isExternal, builder) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryPassGenerator.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/RubyTypeRecoveryPassGenerator.scala similarity index 83% rename from joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryPassGenerator.scala rename to joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/RubyTypeRecoveryPassGenerator.scala index a9afaf7b9d67..1171a97e1e9a 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryPassGenerator.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/rubysrc2cpg/RubyTypeRecoveryPassGenerator.scala @@ -1,13 +1,13 @@ -package io.joern.rubysrc2cpg.passes +package io.joern.x2cpg.frontendspecific.rubysrc2cpg import io.joern.x2cpg.Defines as XDefines +import io.joern.x2cpg.frontendspecific.rubysrc2cpg.Constants.* import io.joern.x2cpg.passes.frontend.* import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt -import io.shiftleft.codepropertygraph.generated.{Cpg, Operators, PropertyNames} import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder, Operators, PropertyNames} +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess -import io.shiftleft.semanticcpg.language.{types, *} -import overflowdb.BatchedUpdate.DiffGraphBuilder class RubyTypeRecoveryPassGenerator(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) extends XTypeRecoveryPassGenerator[File](cpg, config) { @@ -20,6 +20,8 @@ private class RubyTypeRecovery(cpg: Cpg, state: XTypeRecoveryState, iteration: I override def compilationUnits: Iterator[File] = cpg.file.iterator + override def isParallel: Boolean = false + override def generateRecoveryForCompilationUnitTask( unit: File, builder: DiffGraphBuilder @@ -40,7 +42,7 @@ private class RecoverForRubyFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, /** A heuristic method to determine if a call name is a constructor or not. */ override protected def isConstructor(name: String): Boolean = - !name.isBlank && (name == "new" || name == Defines.Initialize) + !name.isBlank && (name == "new" || name == Initialize) override protected def hasTypes(node: AstNode): Boolean = node match { case x: Call if !x.methodFullName.startsWith("") => @@ -55,9 +57,14 @@ private class RecoverForRubyFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, case x @ (_: Identifier | _: Local | _: MethodParameterIn) => symbolTable.append(x, x.getKnownTypes) case call: Call => val tnfs = - if call.methodFullName == XDefines.DynamicCallUnknownFullName || call.methodFullName.startsWith("") - then (call.dynamicTypeHintFullName ++ call.possibleTypes).distinct - else (call.methodFullName +: (call.dynamicTypeHintFullName ++ call.possibleTypes)).distinct + if ( + call.name != "initialize" && (call.methodFullName == XDefines.DynamicCallUnknownFullName || call.methodFullName + .startsWith("")) + ) { + (call.dynamicTypeHintFullName ++ call.possibleTypes).distinct + } else { + (call.methodFullName +: (call.dynamicTypeHintFullName ++ call.possibleTypes)).distinct + } symbolTable.append(call, tnfs.toSet) case _ => @@ -96,7 +103,7 @@ private class RecoverForRubyFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, else fieldAccessParents .filter(_.endsWith(fieldAccessName.stripSuffix(s".${c.name}"))) - .map(x => s"$x:${c.name}") + .map(x => s"$x.${c.name}") } else { types } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/SwiftTypeRecovery.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/SwiftTypeRecovery.scala index a40fd813531c..50f1f2bc8fcc 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/SwiftTypeRecovery.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/SwiftTypeRecovery.scala @@ -5,10 +5,10 @@ import io.joern.x2cpg.Defines.ConstructorMethodName import io.joern.x2cpg.passes.frontend.* import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames} +import io.shiftleft.codepropertygraph.generated.{Operators, Properties, PropertyNames} import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder class SwiftTypeRecoveryPassGenerator(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) extends XTypeRecoveryPassGenerator[File](cpg, config) { @@ -47,9 +47,9 @@ private class RecoverForSwiftFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, override protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match { case x @ (_: Identifier | _: Local | _: MethodParameterIn) - if x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) != Defines.Any => - val typeFullName = x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) - val typeHints = symbolTable.get(LocalVar(x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any))) - typeFullName + if x.propertyOption(Properties.TypeFullName).getOrElse(Defines.Any) != Defines.Any => + val typeFullName = x.propertyOption(Properties.TypeFullName).getOrElse(Defines.Any) + val typeHints = symbolTable.get(LocalVar(typeFullName)) - typeFullName lazy val cpgTypeFullName = cpg.typeDecl.nameExact(typeFullName).fullName.toSet val resolvedTypeHints = if (typeHints.nonEmpty) symbolTable.put(x, typeHints) diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/Base.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/Base.scala index a14daa5f8ad5..a23a71fb4f5b 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/Base.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/Base.scala @@ -4,7 +4,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.PropertyNames import io.shiftleft.passes.CpgPassBase import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} -import io.joern.x2cpg.passes.base._ +import io.joern.x2cpg.passes.base.* object Base { val overlayName: String = "base" @@ -31,11 +31,7 @@ class Base extends LayerCreator { override val description: String = Base.description override def create(context: LayerCreatorContext): Unit = { - val cpg = context.cpg - cpg.graph.indexManager.createNodePropertyIndex(PropertyNames.FULL_NAME) - Base.passes(cpg).zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, index) - } + Base.passes(context.cpg).foreach(_.createAndApply()) } // LayerCreators need one-arg constructor, because they're called by reflection from io.joern.console.Run diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/CallGraph.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/CallGraph.scala index 45ec40470400..b821a148f623 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/CallGraph.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/CallGraph.scala @@ -22,10 +22,7 @@ class CallGraph extends LayerCreator { override val dependsOn: List[String] = List(TypeRelations.overlayName) override def create(context: LayerCreatorContext): Unit = { - val cpg = context.cpg - CallGraph.passes(cpg).zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, index) - } + CallGraph.passes(context.cpg).foreach(_.createAndApply()) } // LayerCreators need one-arg constructor, because they're called by reflection from io.joern.console.Run diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/ControlFlow.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/ControlFlow.scala index d0cdb3b1de5c..c19dcc9bb704 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/ControlFlow.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/ControlFlow.scala @@ -3,7 +3,7 @@ package io.joern.x2cpg.layers import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.Languages import io.shiftleft.passes.CpgPassBase -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} import io.joern.x2cpg.passes.controlflow.CfgCreationPass import io.joern.x2cpg.passes.controlflow.cfgdominator.CfgDominatorPass @@ -31,10 +31,7 @@ class ControlFlow extends LayerCreator { override val dependsOn: List[String] = List(Base.overlayName) override def create(context: LayerCreatorContext): Unit = { - val cpg = context.cpg - ControlFlow.passes(cpg).zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, index) - } + ControlFlow.passes(context.cpg).foreach(_.createAndApply()) } // LayerCreators need one-arg constructor, because they're called by reflection from io.joern.console.Run diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpAst.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpAst.scala index cbbe9ceaf31d..cf06abf67ed7 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpAst.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpAst.scala @@ -1,7 +1,7 @@ package io.joern.x2cpg.layers import better.files.File -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} case class AstDumpOptions(var outDir: String) extends LayerCreatorOptions {} diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpCdg.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpCdg.scala index 2528d7107758..6e92efc27ac8 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpCdg.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpCdg.scala @@ -1,7 +1,7 @@ package io.joern.x2cpg.layers import better.files.File -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} case class CdgDumpOptions(var outDir: String) extends LayerCreatorOptions {} diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpCfg.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpCfg.scala index 1f78125b4177..0b3000c31fc8 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpCfg.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/DumpCfg.scala @@ -1,7 +1,7 @@ package io.joern.x2cpg.layers import better.files.File -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} case class CfgDumpOptions(var outDir: String) extends LayerCreatorOptions {} diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/TypeRelations.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/TypeRelations.scala index 82fc0d70dadd..e143138c88b5 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/TypeRelations.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/TypeRelations.scala @@ -20,10 +20,7 @@ class TypeRelations extends LayerCreator { override val dependsOn: List[String] = List(Base.overlayName) override def create(context: LayerCreatorContext): Unit = { - val cpg = context.cpg - TypeRelations.passes(cpg).zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, index) - } + TypeRelations.passes(context.cpg).foreach(_.createAndApply()) } // Layers need one-arg constructor, because they're called by reflection from io.joern.console.Run diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/ContainsEdgePass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/ContainsEdgePass.scala index 703607803eee..53d16ff44297 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/ContainsEdgePass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/ContainsEdgePass.scala @@ -1,12 +1,13 @@ package io.joern.x2cpg.passes.base import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} import io.shiftleft.passes.ForkJoinParallelCpgPass +import io.shiftleft.semanticcpg.language.* import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* /** This pass has MethodStubCreator and TypeDeclStubCreator as prerequisite for language frontends which do not provide * method stubs and type decl stubs. @@ -15,7 +16,7 @@ class ContainsEdgePass(cpg: Cpg) extends ForkJoinParallelCpgPass[AstNode](cpg) { import ContainsEdgePass._ override def generateParts(): Array[AstNode] = - cpg.graph.nodes(sourceTypes*).asScala.map(_.asInstanceOf[AstNode]).toArray + cpg.graph.nodes(sourceTypes*).cast[AstNode].toArray override def runOnPart(dstGraph: DiffGraphBuilder, source: AstNode): Unit = { // AST is assumed to be a tree. If it contains cycles, then this will give a nice endless loop with OOM diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/FileCreationPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/FileCreationPass.scala index a045ef5971e3..b5bef4d0c258 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/FileCreationPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/FileCreationPass.scala @@ -1,9 +1,8 @@ package io.joern.x2cpg.passes.base import io.joern.x2cpg.utils.LinkingUtil -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.{NewFile, StoredNode} -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyNames} +import io.shiftleft.codepropertygraph.generated.nodes.{File, NewFile, StoredNode} +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes, NodeTypes, PropertyNames} import io.shiftleft.passes.CpgPass import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.FileTraversal @@ -25,7 +24,7 @@ class FileCreationPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { } def createFileIfDoesNotExist(srcNode: StoredNode, destFullName: String): Unit = { - if (destFullName != srcNode.propertyDefaultValue(PropertyNames.FILENAME)) { + if (destFullName != File.PropertyDefaults.Name) { val dstFullName = if (destFullName == "") { FileTraversal.UNKNOWN } else { destFullName } val newFile = newFileNameToNode.getOrElseUpdate( @@ -42,7 +41,7 @@ class FileCreationPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { // Create SOURCE_FILE edges from nodes of various types to FILE linkToSingle( cpg, - srcNodes = cpg.graph.nodes(srcLabels*).toList, + srcNodes = cpg.graph.nodes(srcLabels*).cast[StoredNode].toList, srcLabels = srcLabels, dstNodeLabel = NodeTypes.FILE, edgeType = EdgeTypes.SOURCE_FILE, @@ -50,6 +49,7 @@ class FileCreationPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { originalFileNameToNode.get(x) }, dstFullNameKey = PropertyNames.FILENAME, + dstDefaultPropertyValue = File.PropertyDefaults.Name, dstGraph, Some(createFileIfDoesNotExist) ) diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/MethodDecoratorPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/MethodDecoratorPass.scala index 99b19f36d7a4..19f2bf802a87 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/MethodDecoratorPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/MethodDecoratorPass.scala @@ -3,7 +3,7 @@ package io.joern.x2cpg.passes.base import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.slf4j.{Logger, LoggerFactory} /** Adds a METHOD_PARAMETER_OUT for each METHOD_PARAMETER_IN to the graph and connects those with a PARAMETER_LINK edge. diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/MethodStubCreator.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/MethodStubCreator.scala index 7fa2f42016e5..794127e77402 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/MethodStubCreator.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/MethodStubCreator.scala @@ -3,12 +3,11 @@ package io.joern.x2cpg.passes.base import io.joern.x2cpg.Defines import io.joern.x2cpg.passes.base.MethodStubCreator.createMethodStub import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, EvaluationStrategies, NodeTypes} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ -import overflowdb.BatchedUpdate -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder import scala.collection.mutable import scala.util.Try @@ -24,7 +23,7 @@ class MethodStubCreator(cpg: Cpg) extends CpgPass(cpg) { private val methodFullNameToNode = mutable.LinkedHashMap[String, Method]() private val methodToParameterCount = mutable.LinkedHashMap[CallSummary, Int]() - override def run(dstGraph: BatchedUpdate.DiffGraphBuilder): Unit = { + override def run(dstGraph: DiffGraphBuilder): Unit = { for (method <- cpg.method) { methodFullNameToNode.put(method.fullName, method) } @@ -121,7 +120,7 @@ object MethodStubCreator { val blockNode = NewBlock() .order(1) .argumentIndex(1) - .typeFullName("ANY") + .typeFullName(Defines.Any) dstGraph.addNode(blockNode) dstGraph.addEdge(methodNode, blockNode, EdgeTypes.AST) diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/NamespaceCreator.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/NamespaceCreator.scala index 9d6724e7a607..417c8b24afbf 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/NamespaceCreator.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/NamespaceCreator.scala @@ -4,7 +4,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.NewNamespace import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* /** Creates NAMESPACE nodes and connects NAMESPACE_BLOCKs to corresponding NAMESPACE nodes. * diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/ParameterIndexCompatPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/ParameterIndexCompatPass.scala index a5a3dc2ae508..9323da89cf01 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/ParameterIndexCompatPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/ParameterIndexCompatPass.scala @@ -4,7 +4,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.PropertyNames import io.shiftleft.codepropertygraph.generated.nodes.MethodParameterIn.PropertyDefaults import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* /** Old CPGs use the `order` field to indicate the parameter index while newer CPGs use the `parameterIndex` field. This * pass checks whether `parameterIndex` is not set, in which case the value of `order` is copied over. diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeDeclStubCreator.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeDeclStubCreator.scala index 778c9cd42a9e..7d658d7efb2f 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeDeclStubCreator.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeDeclStubCreator.scala @@ -4,7 +4,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.NodeTypes import io.shiftleft.codepropertygraph.generated.nodes.{NewTypeDecl, TypeDeclBase} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.{FileTraversal, NamespaceTraversal} /** This pass has no other pass as prerequisite. For each `TYPE` node that does not have a corresponding `TYPE_DECL` diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeEvalPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeEvalPass.scala index 00547bf2142d..b51a934c0e98 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeEvalPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeEvalPass.scala @@ -1,14 +1,12 @@ package io.joern.x2cpg.passes.base import io.joern.x2cpg.utils.LinkingUtil -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyNames} +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes, NodeTypes, PropertyNames} +import io.shiftleft.codepropertygraph.generated.nodes.{Local, StoredNode} import io.shiftleft.passes.ForkJoinParallelCpgPass -import overflowdb.Node -import overflowdb.traversal.* - -class TypeEvalPass(cpg: Cpg) extends ForkJoinParallelCpgPass[List[Node]](cpg) with LinkingUtil { +import io.shiftleft.semanticcpg.language.* +class TypeEvalPass(cpg: Cpg) extends ForkJoinParallelCpgPass[List[StoredNode]](cpg) with LinkingUtil { private val srcLabels = List( NodeTypes.METHOD_PARAMETER_IN, NodeTypes.METHOD_PARAMETER_OUT, @@ -24,11 +22,11 @@ class TypeEvalPass(cpg: Cpg) extends ForkJoinParallelCpgPass[List[Node]](cpg) wi NodeTypes.UNKNOWN ) - def generateParts(): Array[List[Node]] = { - cpg.graph.nodes(srcLabels*).toList.grouped(MAX_BATCH_SIZE).toArray + def generateParts(): Array[List[StoredNode]] = { + cpg.graph.nodes(srcLabels*).cast[StoredNode].toList.grouped(MAX_BATCH_SIZE).toArray } - def runOnPart(builder: DiffGraphBuilder, part: List[overflowdb.Node]): Unit = { + def runOnPart(builder: DiffGraphBuilder, part: List[StoredNode]): Unit = { linkToSingle( cpg = cpg, srcNodes = part, @@ -37,6 +35,7 @@ class TypeEvalPass(cpg: Cpg) extends ForkJoinParallelCpgPass[List[Node]](cpg) wi edgeType = EdgeTypes.EVAL_TYPE, dstNodeMap = typeFullNameToNode(cpg, _), dstFullNameKey = PropertyNames.TYPE_FULL_NAME, + dstDefaultPropertyValue = Local.PropertyDefaults.TypeFullName, dstGraph = builder, None ) diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeRefPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeRefPass.scala index 65dbae189c28..d851b16e201d 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeRefPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/base/TypeRefPass.scala @@ -1,21 +1,19 @@ package io.joern.x2cpg.passes.base import io.joern.x2cpg.utils.LinkingUtil -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyNames} +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes, NodeTypes, PropertyNames} +import io.shiftleft.codepropertygraph.generated.nodes.{Type, StoredNode} import io.shiftleft.passes.ForkJoinParallelCpgPass -import overflowdb.Node -import overflowdb.traversal.* - -class TypeRefPass(cpg: Cpg) extends ForkJoinParallelCpgPass[List[Node]](cpg) with LinkingUtil { +import io.shiftleft.semanticcpg.language.* +class TypeRefPass(cpg: Cpg) extends ForkJoinParallelCpgPass[List[StoredNode]](cpg) with LinkingUtil { private val srcLabels = List(NodeTypes.TYPE) - def generateParts(): Array[List[Node]] = { - cpg.graph.nodes(srcLabels*).toList.grouped(MAX_BATCH_SIZE).toArray + def generateParts(): Array[List[StoredNode]] = { + cpg.graph.nodes(srcLabels*).cast[StoredNode].toList.grouped(MAX_BATCH_SIZE).toArray } - def runOnPart(builder: DiffGraphBuilder, part: List[overflowdb.Node]): Unit = { + def runOnPart(builder: DiffGraphBuilder, part: List[StoredNode]): Unit = { linkToSingle( cpg = cpg, srcNodes = part, @@ -24,6 +22,7 @@ class TypeRefPass(cpg: Cpg) extends ForkJoinParallelCpgPass[List[Node]](cpg) wit edgeType = EdgeTypes.REF, dstNodeMap = typeDeclFullNameToNode(cpg, _), dstFullNameKey = PropertyNames.TYPE_DECL_FULL_NAME, + dstDefaultPropertyValue = Type.PropertyDefaults.TypeDeclFullName, dstGraph = builder, None ) diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/DynamicCallLinker.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/DynamicCallLinker.scala index e12539269822..ed3496e7e94c 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/DynamicCallLinker.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/DynamicCallLinker.scala @@ -2,15 +2,14 @@ package io.joern.x2cpg.passes.callgraph import io.joern.x2cpg.Defines.DynamicCallUnknownFullName import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Method, TypeDecl} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Method, StoredNode, Type, TypeDecl} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, PropertyNames} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.slf4j.{Logger, LoggerFactory} -import overflowdb.{NodeDb, NodeRef} import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* /** We compute the set of possible call-targets for each dynamic call, and add them as CALL edges to the graph, based on * call.methodFullName, method.name and method.signature, the inheritance hierarchy and the AST of typedecls and @@ -56,10 +55,10 @@ class DynamicCallLinker(cpg: Cpg) extends CpgPass(cpg) { initMaps() // ValidM maps class C and method name N to the set of // func ptrs implementing N for C and its subclasses - for ( - typeDecl <- cpg.typeDecl; + for { + typeDecl <- cpg.typeDecl method <- typeDecl._methodViaAstOut - ) { + } { val methodName = method.fullName val candidates = allSubclasses(typeDecl.fullName).flatMap { staticLookup(_, method) } validM.put(methodName, candidates) @@ -114,8 +113,8 @@ class DynamicCallLinker(cpg: Cpg) extends CpgPass(cpg) { if (visitedNodes.contains(cur)) return visitedNodes visitedNodes.addOne(cur) - (if (inSuperDirection) cpg.typeDecl.fullNameExact(cur.fullName).flatMap(_.inheritsFromOut.referencedTypeDecl) - else cpg.typ.fullNameExact(cur.fullName).flatMap(_.inheritsFromIn)) + (if (inSuperDirection) cpg.typeDecl.fullNameExact(cur.fullName)._typeViaInheritsFromOut.referencedTypeDecl + else cpg.typ.fullNameExact(cur.fullName).inheritsFromIn) .collectAll[TypeDecl] .to(mutable.LinkedHashSet) match { case classesToEval if classesToEval.isEmpty => visitedNodes @@ -174,16 +173,8 @@ class DynamicCallLinker(cpg: Cpg) extends CpgPass(cpg) { validM.get(call.methodFullName) match { case Some(tgts) => - val callsOut = call.callOut.fullName.toSetImmutable - val tgtMs = tgts - .flatMap(destMethod => - if (cpg.graph.indexManager.isIndexed(PropertyNames.FULL_NAME)) { - methodFullNameToNode(destMethod) - } else { - cpg.method.fullNameExact(destMethod).headOption - } - ) - .toSet + val callsOut = call._callOut.cast[Method].fullName.toSetImmutable + val tgtMs = tgts.flatMap(destMethod => methodFullNameToNode(destMethod)).toSet // Non-overridden methods linked as external stubs should be excluded if they are detected val (externalMs, internalMs) = tgtMs.partition(_.isExternal) (if (externalMs.nonEmpty && internalMs.nonEmpty) internalMs else tgtMs) @@ -209,8 +200,8 @@ class DynamicCallLinker(cpg: Cpg) extends CpgPass(cpg) { } } - private def nodesWithFullName(x: String): Iterable[NodeRef[? <: NodeDb]] = - cpg.graph.indexManager.lookup(PropertyNames.FULL_NAME, x).asScala + private def nodesWithFullName(x: String): Iterator[StoredNode] = + cpg.graph.nodesWithProperty(PropertyNames.FULL_NAME, x).cast[StoredNode] private def methodFullNameToNode(x: String): Option[Method] = nodesWithFullName(x).collectFirst { case x: Method => x } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/MethodRefLinker.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/MethodRefLinker.scala index 86174f9a872d..e0411f8dead4 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/MethodRefLinker.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/MethodRefLinker.scala @@ -1,28 +1,27 @@ package io.joern.x2cpg.passes.callgraph import io.joern.x2cpg.utils.LinkingUtil -import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.Method import io.shiftleft.passes.CpgPass -import overflowdb.traversal.* +import io.shiftleft.semanticcpg.language.* /** This pass has MethodStubCreator and TypeDeclStubCreator as prerequisite for language frontends which do not provide * method stubs and type decl stubs. */ class MethodRefLinker(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { - private val srcLabels = List(NodeTypes.METHOD_REF) - override def run(dstGraph: DiffGraphBuilder): Unit = { // Create REF edges from METHOD_REFs to METHOD linkToSingle( cpg, - srcNodes = cpg.graph.nodes(srcLabels*).toList, + srcNodes = cpg.methodRef.l, srcLabels = List(NodeTypes.METHOD_REF), dstNodeLabel = NodeTypes.METHOD, edgeType = EdgeTypes.REF, dstNodeMap = methodFullNameToNode(cpg, _), dstFullNameKey = PropertyNames.METHOD_FULL_NAME, + dstDefaultPropertyValue = Method.PropertyDefaults.FullName, dstGraph, None ) diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/NaiveCallLinker.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/NaiveCallLinker.scala index e9156f217b49..f9f1332af728 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/NaiveCallLinker.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/callgraph/NaiveCallLinker.scala @@ -3,8 +3,7 @@ package io.joern.x2cpg.passes.callgraph import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.{EdgeTypes, PropertyNames} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ -import overflowdb.traversal.jIteratortoTraversal +import io.shiftleft.semanticcpg.language.* /** Link remaining unlinked calls to methods only by their name (not full name) * @param cpg diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/CfgCreationPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/CfgCreationPass.scala index f1dab9252fb7..cf2ce37a0f43 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/CfgCreationPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/CfgCreationPass.scala @@ -3,7 +3,7 @@ package io.joern.x2cpg.passes.controlflow import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Method import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.joern.x2cpg.passes.controlflow.cfgcreation.CfgCreator /** A pass that creates control flow graphs from abstract syntax trees. diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala index ac157243ac15..f59b5776ffd8 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala @@ -4,7 +4,7 @@ import io.joern.x2cpg.passes.controlflow.cfgcreation.Cfg.CfgEdgeType import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, EdgeTypes, Operators} import io.shiftleft.semanticcpg.language.* -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder /** Translation of abstract syntax trees into control flow graphs * @@ -174,10 +174,18 @@ class CfgCreator(entryNode: Method, diffGraph: DiffGraphBuilder) { cfgForChildren(node) case ControlStructureTypes.MATCH => cfgForMatchExpression(node) + case ControlStructureTypes.THROW => + cfgForThrowStatement(node) case _ => Cfg.empty } + protected def cfgForThrowStatement(node: ControlStructure): Cfg = { + val throwExprCfg = node.astChildren.find(_.order == 1).map(cfgFor).getOrElse(Cfg.empty) + val concatedNatedCfg = throwExprCfg ++ Cfg(entryNode = Option(node)) + concatedNatedCfg.copy(edges = concatedNatedCfg.edges ++ singleEdge(node, exitNode)) + } + /** The CFG for a break/continue statements contains only the break/continue statement as a single entry node. The * fringe is empty, that is, appending another CFG to the break statement will not result in the creation of an edge * from the break statement to the entry point of the other CFG. Labeled breaks are treated like gotos and are added @@ -368,16 +376,17 @@ class CfgCreator(entryNode: Method, diffGraph: DiffGraphBuilder) { val loopExprCfg = children.find(_.order == nLocals + 3).map(cfgFor).getOrElse(Cfg.empty) val bodyCfg = children.find(_.order == nLocals + 4).map(cfgFor).getOrElse(Cfg.empty) - val innerCfg = conditionCfg ++ bodyCfg ++ loopExprCfg - val entryNode = (initExprCfg ++ innerCfg).entryNode + val innerCfg = bodyCfg ++ loopExprCfg + val loopEntryNode = conditionCfg.entryNode.orElse(innerCfg.entryNode) + val entryNode = initExprCfg.entryNode.orElse(loopEntryNode) - val newEdges = edgesFromFringeTo(initExprCfg, innerCfg.entryNode) ++ - edgesFromFringeTo(innerCfg, innerCfg.entryNode) ++ - edgesFromFringeTo(conditionCfg, bodyCfg.entryNode, TrueEdge) ++ { + val newEdges = edgesFromFringeTo(initExprCfg, loopEntryNode) ++ + edgesFromFringeTo(innerCfg, loopEntryNode) ++ + edgesFromFringeTo(conditionCfg, innerCfg.entryNode.orElse(conditionCfg.entryNode), TrueEdge) ++ { if (loopExprCfg.entryNode.isDefined) { edges(takeCurrentLevel(bodyCfg.continues), loopExprCfg.entryNode) } else { - edges(takeCurrentLevel(bodyCfg.continues), innerCfg.entryNode) + edges(takeCurrentLevel(bodyCfg.continues), loopEntryNode) } } @@ -385,7 +394,7 @@ class CfgCreator(entryNode: Method, diffGraph: DiffGraphBuilder) { .from(initExprCfg, conditionCfg, loopExprCfg, bodyCfg) .copy( entryNode = entryNode, - edges = newEdges ++ initExprCfg.edges ++ innerCfg.edges, + edges = newEdges ++ initExprCfg.edges ++ conditionCfg.edges ++ innerCfg.edges, fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel(bodyCfg.breaks).map((_, AlwaysEdge)), breaks = reduceAndFilterLevel(bodyCfg.breaks), continues = reduceAndFilterLevel(bodyCfg.continues) @@ -461,18 +470,32 @@ class CfgCreator(entryNode: Method, diffGraph: DiffGraphBuilder) { val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode) ++ edgesFromFringeTo(conditionCfg, falseCfg.entryNode) - Cfg - .from(conditionCfg, trueCfg, falseCfg) - .copy( - entryNode = conditionCfg.entryNode, - edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges, - fringe = trueCfg.fringe ++ { + val ifStatementFringe = + if (trueCfg.entryNode.isEmpty && falseCfg.entryNode.isEmpty) { + conditionCfg.fringe.withEdgeType(AlwaysEdge) + } else { + val trueFringe = if (trueCfg.entryNode.isDefined) { + trueCfg.fringe + } else { + conditionCfg.fringe.withEdgeType(TrueEdge) + } + + val falseFringe = if (falseCfg.entryNode.isDefined) { falseCfg.fringe } else { conditionCfg.fringe.withEdgeType(FalseEdge) } - } + + trueFringe ++ falseFringe + } + + Cfg + .from(conditionCfg, trueCfg, falseCfg) + .copy( + entryNode = conditionCfg.entryNode, + edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges, + fringe = ifStatementFringe ) } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala index e0823e835ee7..b05e62aee1be 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala @@ -17,7 +17,7 @@ class CfgDominatorFrontier[NodeType](cfgAdapter: CfgAdapter[NodeType], domTreeAd private def withIDom(x: NodeType, preds: Seq[NodeType]) = doms(x).map(i => (x, preds, i)) - def calculate(cfgNodes: Seq[NodeType]): mutable.Map[NodeType, mutable.Set[NodeType]] = { + def calculate(cfgNodes: Iterator[NodeType]): mutable.Map[NodeType, mutable.Set[NodeType]] = { val domFrontier = mutable.Map.empty[NodeType, mutable.Set[NodeType]] for { diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala index 933ba72fdf2b..69e33e8773a0 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala @@ -4,7 +4,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.{Method, StoredNode} import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import scala.collection.mutable diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala index 5bca67fd685f..c47915bf6d88 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala @@ -13,7 +13,7 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ Unknown } import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.joern.x2cpg.passes.controlflow.cfgdominator.{CfgDominatorFrontier, ReverseCpgCfgAdapter} import org.slf4j.{Logger, LoggerFactory} @@ -22,6 +22,9 @@ import org.slf4j.{Logger, LoggerFactory} class CdgPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Method](cpg) { import CdgPass.logger + // 10 is just an arbitrary number - we merely want to log 'a few times' but no more than that + val hasLogged = java.util.concurrent.atomic.AtomicInteger(10) + override def generateParts(): Array[Method] = cpg.method.toArray override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = { @@ -39,16 +42,12 @@ class CdgPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Method](cpg) { case postDomFrontierNode => val nodeLabel = postDomFrontierNode.label val containsIn = postDomFrontierNode._containsIn - if (containsIn == null || !containsIn.hasNext) { - logger.warn(s"Found CDG edge starting at $nodeLabel node. This is most likely caused by an invalid CFG.") - } else { - val method = containsIn.next() + // duplicate check looks (and is) superfluous, but it's a fastpath micro optimization + if (hasLogged.get() > 0 && hasLogged.decrementAndGet() > 0) { + val method = containsIn.nextOption().map(_.toString).getOrElse("N/A") logger.warn( - s"Found CDG edge starting at $nodeLabel node. This is most likely caused by an invalid CFG." + - s" Method: ${method match { - case m: Method => m.fullName; - case other => other.label - }}" + + s"Found CDG edge starting at $nodeLabel node $node <-> ${postDomFrontierNode}. This is most likely caused by an invalid CFG." + + s" Method: ${method}" + s" number of outgoing CFG edges from $nodeLabel node: ${postDomFrontierNode._cfgOut.size}" ) } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/Dereference.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/Dereference.scala deleted file mode 100644 index 2ee1c7ae355a..000000000000 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/Dereference.scala +++ /dev/null @@ -1,35 +0,0 @@ -package io.joern.x2cpg.passes.frontend - -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.Languages -import io.shiftleft.semanticcpg.language._ - -object Dereference { - - def apply(cpg: Cpg): Dereference = cpg.metaData.language.headOption match { - case Some(Languages.NEWC) => CDereference() - case _ => DefaultDereference() - } - -} - -sealed trait Dereference { - - def dereferenceTypeFullName(fullName: String): String - -} - -case class CDereference() extends Dereference { - - /** Types from C/C++ can be annotated with * to indicate being a reference. As our CPG schema currently lacks a - * separate field for that information the * is part of the type full name and needs to be removed when linking. - */ - override def dereferenceTypeFullName(fullName: String): String = fullName.replace("*", "") - -} - -case class DefaultDereference() extends Dereference { - - override def dereferenceTypeFullName(fullName: String): String = fullName - -} diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/TypeNodePass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/TypeNodePass.scala index 179c31d4b6d1..5aef81da1a82 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/TypeNodePass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/TypeNodePass.scala @@ -1,26 +1,22 @@ package io.joern.x2cpg.passes.frontend import io.joern.x2cpg.passes.frontend.TypeNodePass.fullToShortName -import io.shiftleft.codepropertygraph.generated.Cpg +import io.joern.x2cpg.Defines +import io.shiftleft.codepropertygraph.generated.{Cpg, Properties} import io.shiftleft.codepropertygraph.generated.nodes.NewType -import io.shiftleft.passes.{KeyPool, CpgPass} -import io.shiftleft.semanticcpg.language._ -import io.shiftleft.codepropertygraph.generated.PropertyNames +import io.shiftleft.passes.CpgPass +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import scala.collection.mutable -import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal /** Creates a `TYPE` node for each type in `usedTypes` as well as all inheritsFrom type names in the CPG * * Alternatively, set `getTypesFromCpg = true`. If this is set, the `registeredTypes` argument will be ignored. * Instead, type nodes will be created for every unique `TYPE_FULL_NAME` value in the CPG. */ -class TypeNodePass protected ( - registeredTypes: List[String], - cpg: Cpg, - keyPool: Option[KeyPool], - getTypesFromCpg: Boolean -) extends CpgPass(cpg, "types", keyPool) { +class TypeNodePass protected (registeredTypes: List[String], cpg: Cpg, getTypesFromCpg: Boolean) + extends CpgPass(cpg, "types") { protected def typeDeclTypes: mutable.Set[String] = { val typeDeclTypes = mutable.Set[String]() @@ -33,9 +29,8 @@ class TypeNodePass protected ( protected def typeFullNamesFromCpg: Set[String] = { cpg.all - .map(_.property(PropertyNames.TYPE_FULL_NAME)) + .map(_.property(Properties.TypeFullName)) .filter(_ != null) - .map(_.toString) .toSet } @@ -51,7 +46,9 @@ class TypeNodePass protected ( val usedTypesSet = typeDeclTypes ++ typeFullNameValues usedTypesSet.remove("") val usedTypes = - (usedTypesSet.filterInPlace(!_.endsWith(NamespaceTraversal.globalNamespaceName)).toArray :+ "ANY").toSet.sorted + (usedTypesSet + .filterInPlace(!_.endsWith(NamespaceTraversal.globalNamespaceName)) + .toArray :+ Defines.Any).toSet.sorted usedTypes.foreach { typeName => val shortName = fullToShortName(typeName) @@ -65,15 +62,20 @@ class TypeNodePass protected ( } object TypeNodePass { - def withTypesFromCpg(cpg: Cpg, keyPool: Option[KeyPool] = None): TypeNodePass = { - new TypeNodePass(Nil, cpg, keyPool, getTypesFromCpg = true) + def withTypesFromCpg(cpg: Cpg): TypeNodePass = { + new TypeNodePass(Nil, cpg, getTypesFromCpg = true) } - def withRegisteredTypes(registeredTypes: List[String], cpg: Cpg, keyPool: Option[KeyPool] = None): TypeNodePass = { - new TypeNodePass(registeredTypes, cpg, keyPool, getTypesFromCpg = false) + def withRegisteredTypes(registeredTypes: List[String], cpg: Cpg): TypeNodePass = { + new TypeNodePass(registeredTypes, cpg, getTypesFromCpg = false) } def fullToShortName(typeName: String): String = { - typeName.takeWhile(_ != ':').split('.').lastOption.getOrElse(typeName) + if (typeName.endsWith(">")) { + // special case for typeFullName with generics as suffix + typeName.takeWhile(c => c != ':' && c != '<').split('.').lastOption.getOrElse(typeName) + } else { + typeName.takeWhile(_ != ':').split('.').lastOption.getOrElse(typeName) + } } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XConfigFileCreationPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XConfigFileCreationPass.scala index 4a496b1b2855..252bc2d28e01 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XConfigFileCreationPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XConfigFileCreationPass.scala @@ -4,7 +4,7 @@ import better.files.File import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.NewConfigFile import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.utils.IOUtils import org.slf4j.LoggerFactory diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XImportsPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XImportsPass.scala index bb81e30b77ff..3d6ab6bbba53 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XImportsPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XImportsPass.scala @@ -4,7 +4,7 @@ import io.joern.x2cpg.Imports.createImportNodeAndLink import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Call import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment abstract class XImportsPass(cpg: Cpg) extends ForkJoinParallelCpgPass[(Call, Assignment)](cpg) { diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XInheritanceFullNamePass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XInheritanceFullNamePass.scala index 6f42f087cb0a..7a3adaf38617 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XInheritanceFullNamePass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XInheritanceFullNamePass.scala @@ -2,7 +2,7 @@ package io.joern.x2cpg.passes.frontend import io.joern.x2cpg.passes.base.TypeDeclStubCreator import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{EdgeTypes, PropertyNames} import io.shiftleft.passes.ForkJoinParallelCpgPass import io.shiftleft.semanticcpg.language.* @@ -41,7 +41,7 @@ abstract class XInheritanceFullNamePass(cpg: Cpg) extends ForkJoinParallelCpgPas inheritedTypes == Seq("ANY") || inheritedTypes == Seq("object") || inheritedTypes.isEmpty private def extractTypeDeclFromNode(node: AstNode): Option[String] = node match { - case x: Call if x.isCallForImportOut.nonEmpty => + case x: Call if x._isCallForImportOut.nonEmpty => x.isCallForImportOut.importedEntity.map { case imp if relativePathPattern.matcher(imp).matches() => imp.split(pathSep).toList match { diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeHintCallLinker.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeHintCallLinker.scala index 7447a091f588..92733810ca93 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeHintCallLinker.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeHintCallLinker.scala @@ -4,9 +4,8 @@ import io.joern.x2cpg.passes.base.MethodStubCreator import io.joern.x2cpg.passes.frontend.XTypeRecovery.isDummyType import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyNames} +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, NodeTypes, PropertyNames} import io.shiftleft.passes.CpgPass -import io.shiftleft.proto.cpg.Cpg.DispatchTypes import io.shiftleft.semanticcpg.language.* import java.util.regex.Pattern @@ -151,7 +150,7 @@ abstract class XTypeHintCallLinker(cpg: Cpg) extends CpgPass(cpg) { name, fullName, "", - DispatchTypes.DYNAMIC_DISPATCH.name(), + DispatchTypes.DYNAMIC_DISPATCH, argSize, builder, isExternal, diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala index b82ed185384c..6d21e6e751fc 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala @@ -1,23 +1,21 @@ package io.joern.x2cpg.passes.frontend import io.joern.x2cpg.{Defines, X2CpgConfig} -import io.shiftleft.codepropertygraph.generated.{Cpg, DispatchTypes, EdgeTypes, NodeTypes, Operators, PropertyNames} import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.* import io.shiftleft.passes.{CpgPass, CpgPassBase, ForkJoinParallelCpgPass} import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.importresolver.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess} import org.slf4j.{Logger, LoggerFactory} -import overflowdb.BatchedUpdate -import overflowdb.BatchedUpdate.DiffGraphBuilder -import scopt.OParser +import scopt.{DefaultOParserSetup, OParser} import java.util.regex.Pattern import scala.annotation.tailrec import scala.collection.mutable -import scala.util.{Failure, Success, Try} import scala.util.matching.Regex +import scala.util.{Failure, Success, Try} /** @param iterations * the number of iterations to run. @@ -31,7 +29,14 @@ object XTypeRecoveryConfig { def parse(cmdLineArgs: Seq[String]): XTypeRecoveryConfig = { OParser - .parse(parserOptions, cmdLineArgs, XTypeRecoveryConfig()) + .parse( + parserOptions, + cmdLineArgs, + XTypeRecoveryConfig(), + new DefaultOParserSetup { + override def errorOnUnknownArgument = false + } + ) .getOrElse( throw new RuntimeException( s"unable to parse XTypeRecoveryConfig from commandline arguments ${cmdLineArgs.mkString(" ")}" @@ -96,8 +101,7 @@ class XTypeRecoveryState(val config: XTypeRecoveryConfig = XTypeRecoveryConfig() object XTypeRecoveryPassGenerator { private def linkMembersToTheirRefs(cpg: Cpg, builder: DiffGraphBuilder): Unit = { - import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromIteratorExt - import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt + import io.joern.x2cpg.passes.frontend.XTypeRecovery.{AllNodeTypesFromIteratorExt, AllNodeTypesFromNodeExt} def getFieldBaseTypes(fieldAccess: FieldAccess): Iterator[TypeDecl] = { fieldAccess @@ -151,7 +155,7 @@ abstract class XTypeRecoveryPassGenerator[CompilationUnitType <: AstNode]( if (postTypeRecoveryAndPropagation) res.append( new CpgPass(cpg): - override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = { + override def run(builder: DiffGraphBuilder): Unit = { XTypeRecoveryPassGenerator.linkMembersToTheirRefs(cpg, builder) } ) @@ -257,17 +261,13 @@ object XTypeRecovery { */ def isDummyType(typ: String): Boolean = DummyTokens.exists(typ.contains) - @deprecated("please use XTypeRecoveryConfig.parserOptionsForParserConfig", since = "2.0.415") - def parserOptions[R <: X2CpgConfig[R] & TypeRecoveryParserConfig[R]]: OParser[?, R] = - XTypeRecoveryConfig.parserOptionsForParserConfig - // The below are convenience calls for accessing type properties, one day when this pass uses `Tag` nodes instead of // the symbol table then perhaps this would work out better implicit class AllNodeTypesFromNodeExt(x: StoredNode) { def allTypes: Iterator[String] = - (x.property(PropertyNames.TYPE_FULL_NAME, "ANY") +: - (x.property(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) - ++ x.property(PropertyNames.POSSIBLE_TYPES, Seq.empty))).iterator + (x.propertyOption(Properties.TypeFullName).getOrElse("ANY") +: + (x.property(Properties.DynamicTypeHintFullName) ++ + x.property(Properties.PossibleTypes))).iterator def getKnownTypes: Set[String] = { x.allTypes.toSet.filterNot(XTypeRecovery.unknownTypePattern.matches) @@ -301,8 +301,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( state: XTypeRecoveryState ) extends Runnable { - import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt - import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromIteratorExt + import io.joern.x2cpg.passes.frontend.XTypeRecovery.{AllNodeTypesFromIteratorExt, AllNodeTypesFromNodeExt} protected val logger: Logger = LoggerFactory.getLogger(getClass) @@ -436,7 +435,8 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( * @param a * assignment call pointer. */ - protected def visitAssignments(a: Assignment): Set[String] = visitAssignmentArguments(a.argumentOut.l) + protected def visitAssignments(a: Assignment): Set[String] = + visitAssignmentArguments(a.argumentOut.cast[CfgNode].l) protected def visitAssignmentArguments(args: List[AstNode]): Set[String] = args match { case List(i: Identifier, b: Block) => visitIdentifierAssignedToBlock(i, b) @@ -555,7 +555,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( isFieldCache.getOrElseUpdate(i, isFieldUncached(i)) protected def isFieldUncached(i: Identifier): Boolean = - i.method.typeDecl.member.nameExact(i.name).nonEmpty + Try(i.method.typeDecl.member.nameExact(i.name).nonEmpty).getOrElse(false) /** Associates the types with the identifier. This may sometimes be an identifier that should be considered a field * which this method uses [[isField]] to determine. @@ -566,12 +566,19 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( /** Returns the appropriate field parent scope. */ protected def getFieldParents(fa: FieldAccess): Set[String] = { - val fieldName = getFieldName(fa).split(Pattern.quote(pathSep)).last - Try(cpg.member.nameExact(fieldName).typeDecl.fullName.filterNot(_.contains("ANY")).toSet) match - case Failure(exception) => - logger.warn("Unable to obtain name of member's parent type declaration", exception) + getFieldName(fa).split(Pattern.quote(pathSep)).lastOption match { + case Some(fieldName) => + Try(cpg.member.nameExact(fieldName).typeDecl.fullName.filterNot(_.contains("ANY")).toSet) match + case Failure(exception) => + logger.warn( + s"Unable to obtain name of member's parent type declaration: ${cpg.member.nameExact(fieldName).propertiesMap.mkString(",")}" + ) + Set.empty + case Success(typeDeclNames) => typeDeclNames + case None => + logger.warn(s"Unable to find a fieldName: ${debugLocation(fa)}") Set.empty - case Success(typeDeclNames) => typeDeclNames + } } /** Associates the types with the identifier. This may sometimes be an identifier that should be considered a field @@ -810,7 +817,8 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( case ::(_: TypeRef, ::(f: FieldIdentifier, _)) => f.canonicalName case xs => - logger.warn(s"Unhandled field structure ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(fa)}") + val debugInfo = xs.collect { case x: CfgNode => (x.label(), x.code) }.mkString(",") + logger.warn(s"Unhandled field structure $debugInfo @ ${debugLocation(fa)}") wrapName("") } } @@ -1231,7 +1239,8 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( lazy val existingTypes = storedNode.getKnownTypes val hasUnknownTypeFullName = storedNode - .property(PropertyNames.TYPE_FULL_NAME, Defines.Any) + .propertyOption(Properties.TypeFullName) + .getOrElse(Defines.Any) .matches(XTypeRecovery.unknownTypePattern.pattern.pattern()) if (types.nonEmpty && (hasUnknownTypeFullName || types.toSet != existingTypes)) { @@ -1270,10 +1279,12 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( */ protected def storeDefaultTypeInfo(n: StoredNode, types: Seq[String]): Unit = val hasUnknownType = - n.property(PropertyNames.TYPE_FULL_NAME, Defines.Any).matches(XTypeRecovery.unknownTypePattern.pattern.pattern()) + n.propertyOption(Properties.TypeFullName) + .getOrElse(Defines.Any) + .matches(XTypeRecovery.unknownTypePattern.pattern.pattern()) if (types.toSet != n.getKnownTypes || (hasUnknownType && types.nonEmpty)) { - setTypes(n, (n.property(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) ++ types).distinct) + setTypes(n, (n.propertyOption(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME).getOrElse(Seq.empty) ++ types).distinct) } /** If there is only 1 type hint then this is set to the `typeFullName` property and `dynamicTypeHintFullName` is diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/AliasLinkerPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/AliasLinkerPass.scala index 3714a9c3b6c1..6d9509df1d56 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/AliasLinkerPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/AliasLinkerPass.scala @@ -4,6 +4,7 @@ import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.TypeDecl import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyNames} import io.shiftleft.passes.CpgPass +import io.shiftleft.semanticcpg.language.* import io.joern.x2cpg.utils.LinkingUtil class AliasLinkerPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/FieldAccessLinkerPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/FieldAccessLinkerPass.scala index e60479b1be34..1d5bc0bc552d 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/FieldAccessLinkerPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/FieldAccessLinkerPass.scala @@ -1,13 +1,14 @@ package io.joern.x2cpg.passes.typerelations -import io.joern.x2cpg.passes.frontend.Dereference import io.joern.x2cpg.utils.LinkingUtil -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Member, StoredNode} import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.Call +import io.shiftleft.codepropertygraph.generated.nodes.Member +import io.shiftleft.codepropertygraph.generated.nodes.StoredNode import io.shiftleft.passes.CpgPass import io.shiftleft.semanticcpg.language.* -import io.shiftleft.semanticcpg.language.operatorextension.{OpNodes, allFieldAccessTypes} -import io.shiftleft.semanticcpg.utils.MemberAccess +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes +import io.shiftleft.semanticcpg.language.operatorextension.allFieldAccessTypes import org.slf4j.LoggerFactory import scala.jdk.CollectionConverters.* @@ -67,18 +68,16 @@ class FieldAccessLinkerPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { dstFullNameKey: String, dstGraph: DiffGraphBuilder ): Unit = { - val dereference = Dereference(cpg) - cpg.graph.nodes(srcLabels*).asScala.cast[SRC_NODE_TYPE].filterNot(_.outE(edgeType).hasNext).foreach { srcNode => + cpg.graph.nodes(srcLabels*).cast[SRC_NODE_TYPE].filterNot(_.outE(edgeType).hasNext).foreach { srcNode => if (!srcNode.outE(edgeType).hasNext) { getDstFullNames(srcNode).foreach { dstFullName => - val dereferenceDstFullName = dereference.dereferenceTypeFullName(dstFullName) - dstNodeMap(dereferenceDstFullName) match { + dstNodeMap(dstFullName) match { case Some(dstNode) => dstGraph.addEdge(srcNode, dstNode, edgeType) case None if dstNodeMap(dstFullName).isDefined => dstGraph.addEdge(srcNode, dstNodeMap(dstFullName).get, edgeType) case None => - logFailedDstLookup(edgeType, srcNode.label, srcNode.id.toString, dstNodeLabel, dereferenceDstFullName) + logFailedDstLookup(edgeType, srcNode.label, srcNode.id.toString, dstNodeLabel, dstFullName) } } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/TypeHierarchyPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/TypeHierarchyPass.scala index 5f96ba2e76dc..d1a3af47136d 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/TypeHierarchyPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/TypeHierarchyPass.scala @@ -1,10 +1,11 @@ package io.joern.x2cpg.passes.typerelations import io.shiftleft.codepropertygraph.generated.Cpg +import io.joern.x2cpg.utils.LinkingUtil import io.shiftleft.codepropertygraph.generated.nodes.TypeDecl import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyNames} import io.shiftleft.passes.CpgPass -import io.joern.x2cpg.utils.LinkingUtil +import io.shiftleft.semanticcpg.language.* /** Create INHERITS_FROM edges from `TYPE_DECL` nodes to `TYPE` nodes. */ diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/ExternalCommand.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/ExternalCommand.scala index 0e85d34636df..f337ba2a28c5 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/ExternalCommand.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/ExternalCommand.scala @@ -1,39 +1,87 @@ package io.joern.x2cpg.utils -import java.util.concurrent.ConcurrentLinkedQueue -import scala.sys.process.{Process, ProcessLogger} -import scala.util.{Failure, Success, Try} -import scala.jdk.CollectionConverters.* +import io.shiftleft.utils.IOUtils -trait ExternalCommand { - - protected val IsWin: Boolean = scala.util.Properties.isWin +import java.io.File +import java.nio.file.{Path, Paths} +import scala.jdk.CollectionConverters.* +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal - // do not prepend any shell layer by default - // individual frontends may override this - protected val shellPrefix: Seq[String] = Nil +object ExternalCommand { - protected def handleRunResult(result: Try[Int], stdOut: Seq[String], stdErr: Seq[String]): Try[Seq[String]] = { - result match { - case Success(0) => - Success(stdOut) - case _ => + case class ExternalCommandResult(exitCode: Int, stdOut: Seq[String], stdErr: Seq[String]) { + def successOption: Option[Seq[String]] = exitCode match { + case 0 => Some(stdOut) + case _ => None + } + def toTry: Try[Seq[String]] = exitCode match { + case 0 => Success(stdOut) + case nonZeroExitCode => val allOutput = stdOut ++ stdErr - Failure(new RuntimeException(allOutput.mkString(System.lineSeparator()))) + val message = s"""Process exited with code $nonZeroExitCode. Output: + |${allOutput.mkString(System.lineSeparator())} + |""".stripMargin + Failure(new RuntimeException(message)) } } - def run(command: String, cwd: String, extraEnv: Map[String, String] = Map.empty): Try[Seq[String]] = { - val stdOutOutput = new ConcurrentLinkedQueue[String] - val stdErrOutput = new ConcurrentLinkedQueue[String] - val processLogger = ProcessLogger(stdOutOutput.add, stdErrOutput.add) - val process = shellPrefix match { - case Nil => Process(command, new java.io.File(cwd), extraEnv.toList*) - case _ => Process(shellPrefix :+ command, new java.io.File(cwd), extraEnv.toList*) + def run( + command: Seq[String], + cwd: String, + mergeStdErrInStdOut: Boolean = false, + extraEnv: Map[String, String] = Map.empty + ): ExternalCommandResult = { + val builder = new ProcessBuilder() + .command(command.toArray*) + .directory(new File(cwd)) + .redirectErrorStream(mergeStdErrInStdOut) + builder.environment().putAll(extraEnv.asJava) + + val stdOutFile = File.createTempFile("x2cpg", "stdout") + val stdErrFile = Option.when(!mergeStdErrInStdOut)(File.createTempFile("x2cpg", "stderr")) + + try { + builder.redirectOutput(stdOutFile) + stdErrFile.foreach(f => builder.redirectError(f)) + + val process = builder.start() + val returnValue = process.waitFor() + + val stdOut = IOUtils.readLinesInFile(stdOutFile.toPath) + val stdErr = stdErrFile.map(f => IOUtils.readLinesInFile(f.toPath)).getOrElse(Seq.empty) + ExternalCommandResult(returnValue, stdOut, stdErr) + } catch { + case NonFatal(exception) => + ExternalCommandResult(1, Seq.empty, stdErr = Seq(exception.getMessage)) + } finally { + stdOutFile.delete() + stdErrFile.foreach(_.delete()) } - handleRunResult(Try(process.!(processLogger)), stdOutOutput.asScala.toSeq, stdErrOutput.asScala.toSeq) } -} + /** Finds the absolute path to the executable directory (e.g. `/path/to/javasrc2cpg/bin`). Based on the package path + * of a loaded classfile based on some (potentially flakey?) filename heuristics. Context: we want to be able to + * invoke the x2cpg frontends from any directory, not just their install directory, and then invoke other + * executables, like astgen, php-parser et al. + */ + def executableDir(packagePath: Path): Path = { + val packagePathAbsolute = packagePath.toAbsolutePath + val fixedDir = + if (packagePathAbsolute.toString.contains("lib")) { + var dir = packagePathAbsolute + while (dir.toString.contains("lib")) + dir = dir.getParent + dir + } else if (packagePathAbsolute.toString.contains("target")) { + var dir = packagePathAbsolute + while (dir.toString.contains("target")) + dir = dir.getParent + dir + } else { + Paths.get(".") + } -object ExternalCommand extends ExternalCommand + fixedDir.resolve("bin/").toAbsolutePath + } +} diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/KeyPool.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/KeyPool.scala new file mode 100644 index 000000000000..0faa1f1fa216 --- /dev/null +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/KeyPool.scala @@ -0,0 +1,80 @@ +package io.joern.x2cpg.utils + +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +/** A pool of long integers. Using the method `next`, the pool provides the next id in a thread-safe manner. */ +trait KeyPool { + def next: Long +} + +/** A key pool that returns the integers of the interval [first, last] in a thread-safe manner. + */ +class IntervalKeyPool(val first: Long, val last: Long) extends KeyPool { + + /** Get next number in interval or raise if number is larger than `last` + */ + def next: Long = { + if (!valid) { + throw new IllegalStateException("Call to `next` on invalidated IntervalKeyPool.") + } + val n = cur.incrementAndGet() + if (n > last) { + throw new RuntimeException("Pool exhausted") + } else { + n + } + } + + /** Split key pool into `numberOfPartitions` partitions of mostly equal size. Invalidates the current pool to ensure + * that the user does not continue to use both the original pool and pools derived from it via `split`. + */ + def split(numberOfPartitions: Int): Iterator[IntervalKeyPool] = { + valid = false + if (numberOfPartitions == 0) { + Iterator() + } else { + val curFirst = cur.get() + val k = (last - curFirst) / numberOfPartitions + (1 to numberOfPartitions).map { i => + val poolFirst = curFirst + (i - 1) * k + new IntervalKeyPool(poolFirst, poolFirst + k - 1) + }.iterator + } + } + + private val cur: AtomicLong = new AtomicLong(first - 1) + private var valid: Boolean = true +} + +/** A key pool that returns elements of `seq` in order in a thread-safe manner. + */ +class SequenceKeyPool(seq: Seq[Long]) extends KeyPool { + + val seqLen: Int = seq.size + var cur = new AtomicInteger(-1) + + override def next: Long = { + val i = cur.incrementAndGet() + if (i >= seqLen) { + throw new RuntimeException("Pool exhausted") + } else { + seq(i) + } + } +} + +object KeyPoolCreator { + + /** Divide the keyspace into n intervals and return a list of corresponding key pools. + */ + def obtain(n: Long, minValue: Long = 0, maxValue: Long = Long.MaxValue): List[IntervalKeyPool] = { + val nIntervals = Math.max(n, 1) + val intervalLen: Long = (maxValue - minValue) / nIntervals + List.range(0L, nIntervals).map { i => + val first = i * intervalLen + minValue + val last = first + intervalLen - 1 + new IntervalKeyPool(first, last) + } + } + +} diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/LinkingUtil.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/LinkingUtil.scala index 9fd6290b4d8b..6e44467fb722 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/LinkingUtil.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/LinkingUtil.scala @@ -1,92 +1,74 @@ package io.joern.x2cpg.utils -import io.joern.x2cpg.passes.frontend.Dereference -import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{Properties, PropertyNames} +import io.shiftleft.codepropertygraph.generated.{Cpg, Properties, PropertyNames} +import io.shiftleft.codepropertygraph.generated.nodes.NamespaceBlock +import io.shiftleft.codepropertygraph.generated.nodes.Type +import io.shiftleft.semanticcpg.language.* import org.slf4j.{Logger, LoggerFactory} -import overflowdb.traversal.* -import overflowdb.traversal.ChainedImplicitsTemp.* -import overflowdb.{Node, NodeDb, NodeRef, PropertyKey} -import scala.collection.mutable import scala.jdk.CollectionConverters.* trait LinkingUtil { - import overflowdb.BatchedUpdate.DiffGraphBuilder + import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder val MAX_BATCH_SIZE: Int = 100 val logger: Logger = LoggerFactory.getLogger(classOf[LinkingUtil]) def typeDeclFullNameToNode(cpg: Cpg, x: String): Option[TypeDecl] = - nodesWithFullName(cpg, x).collectFirst { case x: TypeDecl => x } + cpg.typeDecl.fullNameExact(x).headOption def typeFullNameToNode(cpg: Cpg, x: String): Option[Type] = - nodesWithFullName(cpg, x).collectFirst { case x: Type => x } + cpg.typ.fullNameExact(x).headOption def methodFullNameToNode(cpg: Cpg, x: String): Option[Method] = - nodesWithFullName(cpg, x).collectFirst { case x: Method => x } + cpg.method.fullNameExact(x).headOption def namespaceBlockFullNameToNode(cpg: Cpg, x: String): Option[NamespaceBlock] = - nodesWithFullName(cpg, x).collectFirst { case x: NamespaceBlock => x } - - def nodesWithFullName(cpg: Cpg, x: String): mutable.Seq[NodeRef[? <: NodeDb]] = - cpg.graph.indexManager.lookup(PropertyNames.FULL_NAME, x).asScala + cpg.namespaceBlock.fullNameExact(x).headOption /** For all nodes `n` with a label in `srcLabels`, determine the value of `n.\$dstFullNameKey`, use that to lookup the * destination node in `dstNodeMap`, and create an edge of type `edgeType` between `n` and the destination node. */ - protected def linkToSingle( cpg: Cpg, - srcNodes: List[Node], + srcNodes: List[StoredNode], srcLabels: List[String], dstNodeLabel: String, edgeType: String, dstNodeMap: String => Option[StoredNode], dstFullNameKey: String, + dstDefaultPropertyValue: Any, dstGraph: DiffGraphBuilder, dstNotExistsHandler: Option[(StoredNode, String) => Unit] ): Unit = { - val dereference = Dereference(cpg) var loggedDeprecationWarning = false srcNodes.foreach { srcNode => // If the source node does not have any outgoing edges of this type // This check is just required for backward compatibility if (srcNode.outE(edgeType).isEmpty) { - val key = new PropertyKey[String](dstFullNameKey) srcNode - .propertyOption(key) - .filter { dstFullName => - val dereferenceDstFullName = dereference.dereferenceTypeFullName(dstFullName) - srcNode.propertyDefaultValue(dstFullNameKey) != dereferenceDstFullName - } - .ifPresent { dstFullName => + .propertyOption[String](dstFullNameKey) + .filter { dstFullName => dstDefaultPropertyValue != dstFullName } + .map { dstFullName => // for `UNKNOWN` this is not always set, so we're using an Option here - val srcStoredNode = srcNode.asInstanceOf[StoredNode] - val dereferenceDstFullName = dereference.dereferenceTypeFullName(dstFullName) - dstNodeMap(dereferenceDstFullName) match { + dstNodeMap(dstFullName) match { case Some(dstNode) => - dstGraph.addEdge(srcStoredNode, dstNode, edgeType) + dstGraph.addEdge(srcNode, dstNode, edgeType) case None if dstNodeMap(dstFullName).isDefined => - dstGraph.addEdge(srcStoredNode, dstNodeMap(dstFullName).get, edgeType) + dstGraph.addEdge(srcNode, dstNodeMap(dstFullName).get, edgeType) case None if dstNotExistsHandler.isDefined => - dstNotExistsHandler.get(srcStoredNode, dereferenceDstFullName) + dstNotExistsHandler.get(srcNode, dstFullName) case _ => - logFailedDstLookup(edgeType, srcNode.label, srcNode.id.toString, dstNodeLabel, dereferenceDstFullName) + logFailedDstLookup(edgeType, srcNode.label, srcNode.id.toString, dstNodeLabel, dstFullName) } } } else { srcNode.out(edgeType).property(Properties.FullName).nextOption() match { - case Some(dstFullName) => - dstGraph.setNodeProperty( - srcNode.asInstanceOf[StoredNode], - dstFullNameKey, - dereference.dereferenceTypeFullName(dstFullName) - ) - case None => logger.info(s"Missing outgoing edge of type $edgeType from node $srcNode") + case Some(dstFullName) => dstGraph.setNodeProperty(srcNode, dstFullNameKey, dstFullName) + case None => logger.info(s"Missing outgoing edge of type $edgeType from node $srcNode") } if (!loggedDeprecationWarning) { logger.info( @@ -110,23 +92,21 @@ trait LinkingUtil { dstGraph: DiffGraphBuilder ): Unit = { var loggedDeprecationWarning = false - val dereference = Dereference(cpg) - cpg.graph.nodes(srcLabels*).asScala.cast[SRC_NODE_TYPE].foreach { srcNode => + cpg.graph.nodes(srcLabels*).cast[SRC_NODE_TYPE].foreach { srcNode => if (!srcNode.outE(edgeType).hasNext) { getDstFullNames(srcNode).foreach { dstFullName => - val dereferenceDstFullName = dereference.dereferenceTypeFullName(dstFullName) - dstNodeMap(dereferenceDstFullName) match { + dstNodeMap(dstFullName) match { case Some(dstNode) => dstGraph.addEdge(srcNode, dstNode, edgeType) case None if dstNodeMap(dstFullName).isDefined => dstGraph.addEdge(srcNode, dstNodeMap(dstFullName).get, edgeType) case None => - logFailedDstLookup(edgeType, srcNode.label, srcNode.id.toString, dstNodeLabel, dereferenceDstFullName) + logFailedDstLookup(edgeType, srcNode.label, srcNode.id.toString, dstNodeLabel, dstFullName) } } } else { val dstFullNames = srcNode.out(edgeType).property(Properties.FullName).l - dstGraph.setNodeProperty(srcNode, dstFullNameKey, dstFullNames.map(dereference.dereferenceTypeFullName)) + dstGraph.setNodeProperty(srcNode, dstFullNameKey, dstFullNames) if (!loggedDeprecationWarning) { logger.info( s"Using deprecated CPG format with already existing $edgeType edge between" + diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/ListUtils.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/ListUtils.scala index dc3c75e8ec24..9d83c91fd5f8 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/ListUtils.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/ListUtils.scala @@ -16,5 +16,8 @@ object ListUtils { case _ => Nil } } + + /** Returns the single element, or None if the list is empty or contains more than one element. */ + def singleOrNone: Option[T] = if list.size == 1 then list.headOption else None } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/DependencyResolver.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/DependencyResolver.scala index 4fd0d93df7b5..5cfdb1f50769 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/DependencyResolver.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/DependencyResolver.scala @@ -18,10 +18,8 @@ case class DependencyResolverParams( ) object DependencyResolver { - private val logger = LoggerFactory.getLogger(getClass) - private val defaultGradleProjectName = "app" - private val defaultGradleConfigurationName = "compileClasspath" - private val MaxSearchDepth: Int = 4 + private val logger = LoggerFactory.getLogger(getClass) + private val MaxSearchDepth: Int = 4 def getCoordinates( projectDir: Path, @@ -31,9 +29,10 @@ object DependencyResolver { if (isMavenBuildFile(buildFile)) // TODO: implement None - else if (isGradleBuildFile(buildFile)) - getCoordinatesForGradleProject(buildFile.getParent, defaultGradleConfigurationName) - else { + else if (isGradleBuildFile(buildFile)) { + // TODO: Don't limit this to the default configuration name + getCoordinatesForGradleProject(buildFile.getParent, "compileClasspath") + } else { logger.warn(s"Found unsupported build file $buildFile") Nil } @@ -46,7 +45,9 @@ object DependencyResolver { projectDir: Path, configuration: String ): Option[collection.Seq[String]] = { - val lines = ExternalCommand.run(s"gradle dependencies --configuration $configuration", projectDir.toString) match { + val lines = ExternalCommand + .run(Seq("gradle", "dependencies", "--configuration,", configuration), projectDir.toString) + .toTry match { case Success(lines) => lines case Failure(exception) => logger.warn( @@ -84,12 +85,14 @@ object DependencyResolver { projectDir: Path ): Option[collection.Seq[String]] = { logger.info("resolving Gradle dependencies at {}", projectDir) - val gradleProjectName = params.forGradle.getOrElse(GradleConfigKeys.ProjectName, defaultGradleProjectName) - val gradleConfiguration = - params.forGradle.getOrElse(GradleConfigKeys.ConfigurationName, defaultGradleConfigurationName) - GradleDependencies.get(projectDir, gradleProjectName, gradleConfiguration) match { - case Some(deps) => Some(deps) - case None => + val maybeProjectNameOverride = params.forGradle.get(GradleConfigKeys.ProjectName) + val maybeConfigurationOverride = params.forGradle.get(GradleConfigKeys.ConfigurationName) + + GradleDependencies.get(projectDir, maybeProjectNameOverride, maybeConfigurationOverride) match { + case dependenciesMap if dependenciesMap.values.exists(_.nonEmpty) => + Option(dependenciesMap.values.flatten.toSet.toSeq) + + case _ => logger.warn(s"Could not download Gradle dependencies for project at path `$projectDir`") None } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/GradleDependencies.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/GradleDependencies.scala index bc8c13918df7..f867cb86391b 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/GradleDependencies.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/GradleDependencies.scala @@ -1,19 +1,40 @@ package io.joern.x2cpg.utils.dependency -import better.files._ +import better.files.* import org.gradle.tooling.{GradleConnector, ProjectConnection} -import org.gradle.tooling.model.GradleProject +import org.gradle.tooling.model.{GradleProject, ProjectIdentifier, Task} import org.gradle.tooling.model.build.BuildEnvironment import org.slf4j.LoggerFactory import java.io.ByteArrayOutputStream import java.nio.file.{Files, Path} -import java.io.{File => JFile} +import java.io.File as JFile import java.util.stream.Collectors -import scala.jdk.CollectionConverters._ +import scala.collection.mutable +import scala.jdk.CollectionConverters.* import scala.util.{Failure, Random, Success, Try, Using} -case class GradleProjectInfo(gradleVersion: String, tasks: Seq[String], hasAndroidSubproject: Boolean = false) { +case class ProjectNameInfo(projectName: String, isSubproject: Boolean) { + override def toString: String = { + if (isSubproject) + s":$projectName" + else + projectName + } + + def makeGradleTaskName(taskName: String): String = { + if (isSubproject) + s"$projectName:$taskName" + else + taskName + } +} + +case class GradleProjectInfo( + subprojects: Map[ProjectNameInfo, List[String]], + gradleVersion: String, + hasAndroidSubproject: Boolean +) { def gradleVersionMajorMinor(): (Int, Int) = { def isValidPart(part: String) = part.forall(Character.isDigit) val parts = gradleVersion.split('.') @@ -27,45 +48,73 @@ case class GradleProjectInfo(gradleVersion: String, tasks: Seq[String], hasAndro } } -object Constants { - val aarFileExtension = "aar" - val gradleAndroidPropertyPrefix = "android." - val gradlePropertiesTaskName = "properties" - val jarInsideAarFileName = "classes.jar" -} - case class GradleDepsInitScript(contents: String, taskName: String, destinationDir: Path) object GradleDependencies { - private val logger = LoggerFactory.getLogger(getClass) - private val initScriptPrefix = "x2cpg.init.gradle" - private val taskNamePrefix = "x2cpgCopyDeps" - private val tempDirPrefix = "x2cpgDependencies" + private val aarFileExtension = "aar" + private val gradleAndroidPropertyPrefix = "android" + private val gradlePropertiesTaskName = "properties" + private val jarInsideAarFileName = "classes.jar" + private val defaultConfigurationName = "releaseRuntimeClasspath" + private val initScriptPrefix = "x2cpg.init.gradle" + private val taskNamePrefix = "x2cpgCopyDeps" + private val tempDirPrefix = "x2cpgDependencies" + private val defaultGradleAppName = "app" + + private val logger = LoggerFactory.getLogger(getClass) // works with Gradle 5.1+ because the script makes use of `task.register`: // https://docs.gradle.org/current/userguide/task_configuration_avoidance.html - private def gradle5OrLaterAndroidInitScript( - taskName: String, - destination: String, - gradleProjectName: String, - gradleConfigurationName: String - ): String = { + private def getInitScriptContent(taskName: String, destination: String, projectInfo: GradleProjectInfo): String = { + val projectConfigurationString = projectInfo.subprojects + .map { case (projectNameInfo, configurationNames) => + val quotedConfigurationNames = configurationNames.map(name => s"\"$name\"").mkString(", ") + s"\"${projectNameInfo.projectName}\": [$quotedConfigurationNames]" + } + .mkString(", ") + + val taskCreationFunction = projectInfo.gradleVersionMajorMinor() match { + case (major, minor) if major >= 5 && minor >= 1 => "tasks.register" + case _ => "tasks.create" + } + + val androidTaskDefinition = Option.when(projectInfo.hasAndroidSubproject)(s""" + |def androidDepsCopyTaskName = taskName + "_androidDeps" + | $taskCreationFunction(androidDepsCopyTaskName, Copy) { + | duplicatesStrategy = 'include' + | into destinationDir + | from project.configurations.find { it.name.equals("androidApis") } + | } + |""".stripMargin) + + val dependsOnAndroidTask = Option.when(projectInfo.hasAndroidSubproject)("dependsOn androidDepsCopyTaskName") + s""" |allprojects { | afterEvaluate { project -> | def taskName = "$taskName" | def destinationDir = "${destination.replaceAll("\\\\", "/")}" - | def gradleProjectName = "$gradleProjectName" - | def gradleConfigurationName = "$gradleConfigurationName" + | def gradleProjectConfigurations = [$projectConfigurationString] | - | if (project.name.equals(gradleProjectName)) { + | if (gradleProjectConfigurations.containsKey(project.name)) { + | def gradleConfigurationNames = gradleProjectConfigurations.get(project.name) + | | def compileDepsCopyTaskName = taskName + "_compileDeps" - | tasks.register(compileDepsCopyTaskName, Copy) { - | def selectedConfig = project.configurations.find { it.name.equals(gradleConfigurationName) } + | $taskCreationFunction(compileDepsCopyTaskName, Copy) { + | + | def selectedConfigs = project.configurations.findAll { + | configuration -> gradleConfigurationNames.contains(configuration.getName()) + | } + | | def componentIds = [] - | if (selectedConfig != null) { - | componentIds = selectedConfig.incoming.resolutionResult.allDependencies.collect { it.selected.id } + | if (!selectedConfigs.isEmpty()) { + | for (selectedConfig in selectedConfigs) { + | componentIds = selectedConfig.incoming.resolutionResult.allDependencies.findAll { + | dep -> dep instanceof org.gradle.api.internal.artifacts.result.DefaultResolvedDependencyResult + | } .collect { it.selected.id } + | } | } + | | def result = dependencies.createArtifactResolutionQuery() | .forComponents(componentIds) | .withArtifacts(JvmLibrary, SourcesArtifact) @@ -74,14 +123,9 @@ object GradleDependencies { | into destinationDir | from result.resolvedComponents.collect { it.getArtifacts(SourcesArtifact).collect { it.file } } | } - | def androidDepsCopyTaskName = taskName + "_androidDeps" - | tasks.register(androidDepsCopyTaskName, Copy) { - | duplicatesStrategy = 'include' - | into destinationDir - | from project.configurations.find { it.name.equals("androidApis") } - | } - | tasks.register(taskName, Copy) { - | dependsOn androidDepsCopyTaskName + | ${androidTaskDefinition.getOrElse("")} + | $taskCreationFunction(taskName, Copy) { + | ${dependsOnAndroidTask.getOrElse("")} | dependsOn compileDepsCopyTaskName | } | } @@ -90,40 +134,9 @@ object GradleDependencies { |""".stripMargin } - // this init script _should_ work with Gradle >=4, but has not been tested thoroughly - // TODO: add test cases for older Gradle versions - private def gradle5OrLaterInitScript( - taskName: String, - destination: String, - gradleConfigurationName: String - ): String = { - val into = destination.replaceAll("\\\\", "/") - val fromConfigurations = - Set(s"from configurations.$gradleConfigurationName", "from configurations.runtimeClasspath").mkString("\n") - s""" - |allprojects { - | apply plugin: 'java' - | task $taskName(type: Copy) { - | $fromConfigurations - | into "$into" - | } - |} - |""".stripMargin - } - - private def makeInitScript( - destinationDir: Path, - forAndroid: Boolean, - gradleProjectName: String, - gradleConfigurationName: String - ): GradleDepsInitScript = { + private def makeInitScript(destinationDir: Path, projectInfo: GradleProjectInfo): GradleDepsInitScript = { val taskName = taskNamePrefix + "_" + (Random.alphanumeric take 8).toList.mkString - val content = - if (forAndroid) { - gradle5OrLaterAndroidInitScript(taskName, destinationDir.toString, gradleProjectName, gradleConfigurationName) - } else { - gradle5OrLaterInitScript(taskName, destinationDir.toString, gradleConfigurationName) - } + val content = getInitScriptContent(taskName, destinationDir.toString, projectInfo) GradleDepsInitScript(content, taskName, destinationDir) } @@ -131,39 +144,170 @@ object GradleDependencies { GradleConnector.newConnector().forProjectDirectory(projectDir).connect() } - private def getGradleProjectInfo(projectDir: Path, projectName: String): Option[GradleProjectInfo] = { + private def getConfigurationsWithDependencies(dependenciesOutput: String): List[String] = { + // TODO: this is a heuristic for matching configuration names based on a sample of open source projects. + // either add more options to this or revise the approach completely if this turns out to miss too much. + val configurationNameRegex = raw"(\S*([rR]elease|[rR]untime)\S*) -.+$$".r + val lines = dependenciesOutput.lines.iterator().asScala + val results = mutable.Set[String]() + + while (lines.hasNext) { + val line = lines.next() + line match { + case configurationNameRegex(configurationName, _) if lines.hasNext => + val next = lines.next() + if (next != "No dependencies") { + results.addOne(configurationName) + } + lines.takeWhile(_.nonEmpty) + + case _ => + lines.takeWhile(_.nonEmpty) + } + } + + results.filterNot(_.toLowerCase.contains("test")).toList + } + + private def getGradleProjectInfo( + projectDir: Path, + projectNameOverride: Option[String], + configurationNameOverride: Option[String] + ): Option[GradleProjectInfo] = { Try(makeConnection(projectDir.toFile)) match { case Success(gradleConnection) => Using.resource(gradleConnection) { connection => try { val buildEnv = connection.getModel[BuildEnvironment](classOf[BuildEnvironment]) val project = connection.getModel[GradleProject](classOf[GradleProject]) - val hasAndroidPrefixGradleProperty = - runGradleTask(connection, Constants.gradlePropertiesTaskName) match { + + val availableProjectNames = ProjectNameInfo(project.getName, false) :: project.getChildren.asScala + .map(child => ProjectNameInfo(child.getName, true)) + .toList + + val availableProjectNamesString = availableProjectNames.mkString(" ") + + logger.debug(s"Found gradle project names ${availableProjectNames.mkString(" ")}") + + val selectedProjectNames = if (projectNameOverride.isDefined) { + val overrideName = projectNameOverride.get + availableProjectNames.find(_.projectName == overrideName) match { + case Some(projectInfo) => + logger.debug(s"Only fetching dependencies for overridden project name $overrideName") + projectInfo :: Nil + + case None => + logger.warn( + s"Project name override was specified for dependency fetching ($overrideName), but no such project found." + ) + logger.warn( + s"Falling back to fetching dependencies for all available project names: $availableProjectNamesString" + ) + availableProjectNames + } + } else { + availableProjectNames.find(_.projectName == defaultGradleAppName) match { + case Some(defaultProjectInfo) => + // TODO: This is a temporary check to avoid issues that could arise from subprojects using conflicting + // versions of dependencies. Ideally dependencies for all of these projects will be fetched with + // any conflicts handled in the consumer. + logger.debug(s"Found project with default name ($defaultGradleAppName)") + logger.debug(s"Fetching dependencies only for default project ($defaultGradleAppName)") + defaultProjectInfo :: Nil + + case None => + logger.debug(s"No project name override or project with default name ($defaultGradleAppName) found.") + logger.debug(s"Fetching dependencies for all available projects: $availableProjectNamesString") + availableProjectNames + } + } + + val selectedConfigurations = selectedProjectNames.flatMap { projectNameInfo => + val dependenciesTaskName = projectNameInfo.makeGradleTaskName("dependencies") + + val availableConfigurations = runGradleTask(connection, dependenciesTaskName) match { case Some(out) => - out.split('\n').exists(_.startsWith(Constants.gradleAndroidPropertyPrefix)) - case None => false + getConfigurationsWithDependencies(out) match { + case Nil => + logger.debug(s"No configurations with dependencies found for project $projectNameInfo") + Nil + case deps => + logger.debug( + s"Found the following configurations with dependencies for project $projectNameInfo: ${deps.mkString(", ")}" + ) + deps + } + case None => + logger.warn(s"Failure executing dependencies task $dependenciesTaskName") + Nil } - val info = GradleProjectInfo( - buildEnv.getGradle.getGradleVersion, - project.getTasks.asScala.map(_.getName).toSeq, - hasAndroidPrefixGradleProperty - ) - if (hasAndroidPrefixGradleProperty) { - val validProjectNames = List(project.getName) ++ project.getChildren.getAll.asScala.map(_.getName) - logger.debug(s"Found Gradle projects: ${validProjectNames.mkString(",")}") - if (!validProjectNames.contains(projectName)) { - val validProjectNamesStr = validProjectNames.mkString(",") - logger.warn( - s"The provided Gradle project name `$projectName` is is not part of the valid project names: `$validProjectNamesStr`" - ) - None + + val availableConfigurationsString = availableConfigurations.mkString(", ") + + val selectedConfigurations = if (availableConfigurations.isEmpty) { + // Skip logging below, since no available configurations already logged + Nil + } else if (configurationNameOverride.isDefined) { + val overrideName = configurationNameOverride.get + availableConfigurations.find(_ == overrideName) match { + case Some(configurationName) => + logger.debug(s"Only fetching dependencies for overridden configuration $overrideName") + configurationName :: Nil + + case None => + logger.warn( + s"Configuration name override was specified for dependency fetching ($overrideName), but no such configuration found for project $projectNameInfo." + ) + logger.warn( + s"Falling back to fetching dependencies for all available configurations: $availableConfigurationsString" + ) + availableConfigurations + } } else { - Some(info) + availableConfigurations.find(_ == defaultConfigurationName) match { + case Some(defaultConfigurationName) => + // TODO: This is a temporary check to avoid issues that could arise from subprojects using conflicting + // versions of dependencies. Ideally dependencies for all of these configurations will be fetched with + // any conflicts handled in the consumer. + logger.debug( + s"Found default configuration name ($defaultConfigurationName) for project $projectNameInfo" + ) + logger.debug( + s"Fetching dependencies only for default configuration ($defaultConfigurationName) for project $projectNameInfo" + ) + defaultConfigurationName :: Nil + + case None => + logger.debug( + s"No configuration override or configuration with default name ($defaultConfigurationName) found for project $projectNameInfo." + ) + logger.debug( + s"Fetching dependencies for all available configurations for project $projectNameInfo: $availableConfigurationsString" + ) + availableConfigurations + } + } + + Option.when(selectedConfigurations.nonEmpty) { + projectNameInfo -> selectedConfigurations + } + }.toMap + + val includesAndroidProject = selectedProjectNames.exists { projectNameInfo => + val propertiesTaskName = projectNameInfo.makeGradleTaskName(gradlePropertiesTaskName) + + runGradleTask(connection, propertiesTaskName) match { + case Some(out) => + out.lines().iterator().asScala.exists(_.startsWith(gradleAndroidPropertyPrefix)) + case None => false } - } else { - Some(info) } + + val gradleVersion = buildEnv.getGradle.getGradleVersion + + val gradleProjectInfo = GradleProjectInfo(selectedConfigurations, gradleVersion, includesAndroidProject) + + Option(gradleProjectInfo) } catch { case t: Throwable => logger.warn(s"Caught exception while trying use Gradle connection: ${t.getMessage}") @@ -198,15 +342,17 @@ object GradleDependencies { private def runGradleTask( connection: ProjectConnection, - initScript: GradleDepsInitScript, + taskName: String, + destinationDir: Path, initScriptPath: String ): Option[collection.Seq[String]] = { Using.resources(new ByteArrayOutputStream, new ByteArrayOutputStream) { case (stdoutStream, stderrStream) => - logger.info(s"Executing gradle task '${initScript.taskName}'...") + logger.debug(s"Executing gradle task '${taskName}'...") + Try( connection .newBuild() - .forTasks(initScript.taskName) + .forTasks(taskName) .withArguments("--init-script", initScriptPath) .setStandardOutput(stdoutStream) .setStandardError(stderrStream) @@ -215,14 +361,26 @@ object GradleDependencies { case Success(_) => val result = Files - .list(initScript.destinationDir) + .list(destinationDir) .collect(Collectors.toList[Path]) .asScala .map(_.toAbsolutePath.toString) - logger.info(s"Resolved `${result.size}` dependency files.") + logger.info(s"Task $taskName resolved `${result.size}` dependency files.") Some(result) case Failure(ex) => logger.warn(s"Caught exception while executing Gradle task: ${ex.getMessage}") + val androidSdkError = "Define a valid SDK location with an ANDROID_HOME environment variable" + if (stderrStream.toString.contains(androidSdkError)) { + logger.warn( + "A missing Android SDK configuration caused gradle dependency fetching failures. Please define a valid SDK location with an ANDROID_HOME environment variable or by setting the sdk.dir path in your project's local properties file" + ) + } + if (stderrStream.toString.contains("Could not compile initialization script")) { + val scriptContents = File(initScriptPath).contentAsString + logger.debug( + s"########## INITIALIZATION_SCRIPT ##########\n$scriptContents\n###########################################" + ) + } logger.debug(s"Gradle task execution stdout: \n$stdoutStream") logger.debug(s"Gradle task execution stderr: \n$stderrStream") None @@ -231,14 +389,14 @@ object GradleDependencies { } private def extractClassesJarFromAar(aar: File): Option[Path] = { - val newPath = aar.path.toString.replaceFirst(Constants.aarFileExtension + "$", "jar") + val newPath = aar.path.toString.replaceFirst(aarFileExtension + "$", "jar") val aarUnzipDirSuffix = ".unzipped" val outDir = File(aar.path.toString + aarUnzipDirSuffix) - aar.unzipTo(outDir, _.getName == Constants.jarInsideAarFileName) + aar.unzipTo(outDir, _.getName == jarInsideAarFileName) val outFile = File(newPath) val classesJarEntries = outDir.listRecursively - .filter(_.path.getFileName.toString == Constants.jarInsideAarFileName) + .filter(_.path.getFileName.toString == jarInsideAarFileName) .toList if (classesJarEntries.size != 1) { logger.warn(s"Found aar file without `classes.jar` inside at path ${aar.path}") @@ -258,60 +416,61 @@ object GradleDependencies { // a destination directory. private[dependency] def get( projectDir: Path, - projectName: String, - configurationName: String - ): Option[collection.Seq[String]] = { - logger.info(s"Fetching Gradle project information at path `$projectDir` with project name `$projectName`.") - getGradleProjectInfo(projectDir, projectName) match { + projectNameOverride: Option[String], + configurationNameOverride: Option[String] + ): Map[String, collection.Seq[String]] = { + logger.info(s"Fetching Gradle project information at path `$projectDir`.") + getGradleProjectInfo(projectDir, projectNameOverride, configurationNameOverride) match { case Some(projectInfo) if projectInfo.gradleVersionMajorMinor()._1 < 5 => logger.warn(s"Unsupported Gradle version `${projectInfo.gradleVersion}`") - None + Map.empty + case Some(projectInfo) => Try(File.newTemporaryDirectory(tempDirPrefix).deleteOnExit()) match { case Success(destinationDir) => Try(File.newTemporaryFile(initScriptPrefix).deleteOnExit()) match { case Success(initScriptFile) => - val initScript = - makeInitScript(destinationDir.path, projectInfo.hasAndroidSubproject, projectName, configurationName) + val initScript = makeInitScript(destinationDir.path, projectInfo) initScriptFile.write(initScript.contents) - logger.info( - s"Downloading dependencies for configuration `$configurationName` of project `$projectName` at `$projectDir` into `$destinationDir`..." - ) Try(makeConnection(projectDir.toFile)) match { case Success(connection) => Using.resource(connection) { c => - runGradleTask(c, initScript, initScriptFile.pathAsString) match { - case Some(deps) => - Some(deps.map { d => - if (!d.endsWith(Constants.aarFileExtension)) d + projectInfo.subprojects.keys.flatMap { projectNameInfo => + val taskName = projectNameInfo.makeGradleTaskName(initScript.taskName) + + runGradleTask(c, taskName, initScript.destinationDir, initScriptFile.pathAsString) map { deps => + val depsOutput = deps.map { d => + if (!d.endsWith(aarFileExtension)) d else extractClassesJarFromAar(File(d)) match { case Some(path) => path.toString case None => d } - }) - case None => None - } + } + + projectNameInfo.projectName -> depsOutput + } + }.toMap } case Failure(ex) => logger.warn(s"Caught exception while trying to establish a Gradle connection: ${ex.getMessage}") logger.debug(s"Full exception: ", ex) - None + Map.empty } case Failure(ex) => logger.warn(s"Could not create temporary file for Gradle init script: ${ex.getMessage}") logger.debug(s"Full exception: ", ex) - None + Map.empty } case Failure(ex) => logger.warn(s"Could not create temporary directory for saving dependency files: ${ex.getMessage}") logger.debug("Full exception: ", ex) - None + Map.empty } case None => logger.warn("Could not fetch Gradle project information") - None + Map.empty } } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/MavenDependencies.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/MavenDependencies.scala index 4594359e9a31..e7323085ee82 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/MavenDependencies.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/dependency/MavenDependencies.scala @@ -14,17 +14,25 @@ object MavenDependencies { // also separate this from fetchCommandWithOpts to log a version that clearly separates options we provide from // options specified by the user via the MAVEN_CLI_OPTS environment variable, while also making it clear that this // environment variable is being considered. - private val fetchCommand = - s"mvn $$$MavenCliOpts --fail-never -B dependency:build-classpath -DincludeScope=compile -Dorg.slf4j.simpleLogger.defaultLogLevel=info -Dorg.slf4j.simpleLogger.logFile=System.out" + private val fetchArgs = + Vector( + "--fail-never", + "-B", + "dependency:build-classpath", + "-DincludeScope=compile", + "-Dorg.slf4j.simpleLogger.defaultLogLevel=info", + "-Dorg.slf4j.simpleLogger.logFile=System.out" + ) - private val fetchCommandWithOpts = { + private val fetchCommandWithOpts: Seq[String] = { // These options suppress output, so if they're provided we won't get any results. // "-q" and "--quiet" are the only ones that would realistically be used. val optionsToStrip = Set("-h", "--help", "-q", "--quiet", "-v", "--version") - val mavenOpts = Option(System.getenv(MavenCliOpts)).getOrElse("") - val mavenOptsStripped = mavenOpts.split(raw"\s").filterNot(optionsToStrip.contains).mkString(" ") - fetchCommand.replace(s"$$$MavenCliOpts", mavenOptsStripped) + val cli = org.apache.commons.exec.CommandLine("mvn") + cli.addArguments(System.getenv(MavenCliOpts), false) // a null from getenv() does not add any argument + + cli.toStrings.toIndexedSeq.filterNot(optionsToStrip.contains) ++ fetchArgs } private def logErrors(output: String): Unit = { @@ -34,13 +42,13 @@ object MavenDependencies { "The compile class path may be missing or partial.\n" + "Results will suffer from poor type information.\n" + "To fix this issue, please ensure that the below command can be executed successfully from the project root directory:\n" + - fetchCommand + "\n\n", - output + s"mvn $$$MavenCliOpts " + fetchArgs.mkString(" ") + "\n\n" ) + logger.debug(s"Full maven error output:\n$output") } private[dependency] def get(projectDir: Path): Option[collection.Seq[String]] = { - val lines = ExternalCommand.run(fetchCommandWithOpts, projectDir.toString) match { + val lines = ExternalCommand.run(fetchCommandWithOpts, projectDir.toString).toTry match { case Success(lines) => if (lines.contains("[INFO] Build failures were ignored.")) { logErrors(lines.mkString(System.lineSeparator())) diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/server/FrontendHTTPClient.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/server/FrontendHTTPClient.scala new file mode 100644 index 000000000000..f32c2ad88ddb --- /dev/null +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/server/FrontendHTTPClient.scala @@ -0,0 +1,64 @@ +package io.joern.x2cpg.utils.server + +import java.io.IOException +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.URI +import java.net.http.HttpRequest.BodyPublishers +import java.net.http.HttpResponse.BodyHandlers +import scala.util.Failure +import scala.util.Success +import scala.util.Try + +/** Represents an HTTP client for interacting with a frontend server. + * + * This class provides functionality to create and send HTTP requests to a specified frontend server. The server's host + * needs to be configured. + * + * @param port + * The port of the frontend server. + */ +case class FrontendHTTPClient(port: Int) { + + /** The underlying HTTP client used to send requests. */ + private val underlyingClient: HttpClient = HttpClient.newBuilder().build() + + /** Builds an HTTP POST request with the given arguments. + * + * The request is sent to the configured host and port, with a URI path defined by `FrontendHTTPDefaults.route`. The + * request body is constructed from the `args` array, which is concatenated into a single string separated by "&" and + * sent as `application/x-www-form-urlencoded`. + * + * @param args + * An array of arguments to be included in the POST request body. + * @return + * The constructed `HttpRequest` object. + */ + def buildRequest(args: Array[String]): HttpRequest = { + HttpRequest + .newBuilder() + .uri(URI.create(s"http://localhost:$port/run")) + .header("Content-Type", "application/x-www-form-urlencoded") + .POST(BodyPublishers.ofString(args.mkString("&"))) + .build() + } + + /** Sends the given HTTP request and returns the response body if successful. + * + * This method sends the provided `HttpRequest` built with `buildRequest` using the underlying HTTP client. If the + * response status code is 200, the response body is returned as a `Success`. If the status code indicates a failure, + * a `Failure` containing an `IOException` with the error details is returned. + * + * @param req + * The `HttpRequest` to be sent. + * @return + * A `Try[String]` containing the response body in case of success, or an `IOException` in case of failure. + */ + def sendRequest(req: HttpRequest): Try[String] = { + val resp = underlyingClient.send(req, BodyHandlers.ofString()) + resp match { + case r if r.statusCode() == 200 => Success(resp.body()) + case r => Failure(new IOException(s"Sending request failed with code ${r.statusCode()}: ${r.body()}")) + } + } +} diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/server/FrontendHTTPServer.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/server/FrontendHTTPServer.scala new file mode 100644 index 000000000000..14ec8c04bc88 --- /dev/null +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/utils/server/FrontendHTTPServer.scala @@ -0,0 +1,184 @@ +package io.joern.x2cpg.utils.server + +import io.joern.x2cpg.X2Cpg +import io.joern.x2cpg.X2CpgConfig +import io.joern.x2cpg.X2CpgFrontend +import io.joern.x2cpg.X2CpgMain +import net.freeutils.httpserver.HTTPServer +import net.freeutils.httpserver.HTTPServer.Context +import org.slf4j.LoggerFactory + +import java.util.concurrent.Executors +import java.util.concurrent.ExecutorService +import scala.annotation.tailrec +import scala.jdk.CollectionConverters.ListHasAsScala +import scala.util.Failure +import scala.util.Random +import scala.util.Success +import scala.util.Try + +/** Companion object for `FrontendHTTPServer` providing default executor configurations. */ +object FrontendHTTPServer { + + /** ExecutorService for single-threaded execution. */ + def singleThreadExecutor(): ExecutorService = Executors.newSingleThreadExecutor() + + /** ExecutorService for cached thread pool execution. */ + def cachedThreadPoolExecutor(): ExecutorService = Executors.newCachedThreadPool() + + /** Default ExecutorService used by `FrontendHTTPServer`. */ + def defaultExecutor(): ExecutorService = cachedThreadPoolExecutor() + +} + +/** A trait representing a frontend HTTP server for handling operations any subclass of `X2CpgMain` may offer via its + * main function. This trait provides methods and configurations for setting up an HTTP server that processes requests + * related to `X2CpgMain`. It includes handling request execution either in a single-threaded or multi-threaded manner, + * depending on the executor configuration. + * + * @tparam T + * The type parameter representing the X2Cpg configuration. + * @tparam X + * The type parameter representing the X2Cpg frontend. + */ +trait FrontendHTTPServer[T <: X2CpgConfig[T], X <: X2CpgFrontend[T]] { this: X2CpgMain[T, X] => + + /** Logger instance for logging server-related information. */ + private val logger = LoggerFactory.getLogger(this.getClass) + + /** Optionally holds the underlying HTTP server instance. */ + private var underlyingServer: Option[HTTPServer] = None + + /** Creates a new default configuration for the inheriting `X2CpgFrontend`. + * + * This method should be overridden by implementations to provide the default configuration object needed for the + * `X2CpgFrontend` operation. + * + * @return + * A new instance of the configuration `T`. + */ + protected def newDefaultConfig(): T + + /** ExecutorService used to execute HTTP requests. + * + * This can be overridden to switch between single-threaded and multi-threaded execution. By default, it uses the + * cached thread pool executor from `FrontendHTTPServer`. + */ + protected val executor: ExecutorService = FrontendHTTPServer.defaultExecutor() + + /** Handler for HTTP requests, providing functionality to handle specific routes. + * + * @param server + * The underlying HTTP server instance. + */ + protected class FrontendHTTPHandler(val server: HTTPServer) { + + /** Handles POST requests to the "/run" endpoint. + * + * This method is annotated to handle POST requests directed to the `/run` path. The request `req` is expected to + * include `input`, `output`, and (optionally) frontend arguments (unbounded). The request is expected to be sent + * `application/x-www-form-urlencoded`. The provided `X2CpgFrontend` is run with these input/output/arguments and + * the resulting CPG output path is returned in the response `resp` and status code 200. In case of a failure, + * status code 400 is sent together with a response containing the reason. + * + * @param req + * The HTTP request received by the server. + * @param resp + * The HTTP response to be sent by the server. + * @return + * The HTTP status code for the response. + */ + @Context(value = "/run", methods = Array("POST")) + def run(req: server.Request, resp: server.Response): Int = { + resp.getHeaders.add("Content-Type", "text/plain") + resp.getHeaders.add("Connection", "close") + + val params = req.getParamsList.asScala + val outputDir = params + .collectFirst { case Array(arg, value) if arg == "output" => value } + .getOrElse(X2CpgConfig.defaultOutputPath) + val arguments = params.collect { + case Array(arg, value) if arg == "input" => Array(value) + case Array(arg, value) if value.strip().isEmpty => Array(s"--$arg") + case Array(arg, value) => Array(s"--$arg", value) + }.flatten + logger.debug("Got POST with arguments: " + arguments.mkString(" ")) + + val config = X2Cpg + .parseCommandLine(arguments.toArray, cmdLineParser, newDefaultConfig()) + .getOrElse(newDefaultConfig()) + Try(frontend.run(config)) match { + case Failure(exception) => + resp.send(400, exception.getMessage) + case Success(_) => + resp.send(200, outputDir) + } + 0 + } + } + + /** Stops the underlying HTTP server if it is running. + * + * This method checks if the `underlyingServer` is defined and, if so, stops the server. It also logs a debug message + * indicating that the server has been stopped. If the server is not running, this method does nothing. + */ + def stop(): Unit = { + underlyingServer.foreach { server => + executor.shutdown() + server.stop() + logger.debug("Server stopped.") + } + } + + private def randomPort(): Int = { + val random = new Random() + 10000 + random.nextInt(65000) + } + + private def internalServerStart(): Try[Int] = { + val port = randomPort() + try { + val server = new HTTPServer(port) + val host = server.getVirtualHost(null) + host.addContexts(new FrontendHTTPHandler(server)) + server.setExecutor(executor) + server.start() + underlyingServer = Some(server) + Success(port) + } catch { + case exception: Throwable => Failure(exception) + } finally { + Runtime.getRuntime.addShutdownHook(new Thread(() => { + stop() + })) + } + } + + private def retryUntilSuccess[F](f: () => Try[F], maxAttempts: Int): F = { + @tailrec + def attempt(remainingAttempts: Int): F = { + f() match { + case Success(port) => port + case Failure(_) if remainingAttempts > 1 => attempt(remainingAttempts - 1) + case Failure(exception) => throw exception + } + } + attempt(maxAttempts) + } + + /** Starts the HTTP server. + * + * This method initializes the `underlyingServer`, sets the executor, and adds the appropriate contexts using the + * `FrontendHTTPHandler`. It then starts the server and prints the server's port to stdout. Additionally, a shutdown + * hook is added to ensure that the server is properly stopped when the application is terminated. + * + * @return + * The port this server is bound to which is chosen randomly until success (default number of attempts: 10) + */ + def startup(): Int = { + val port = retryUntilSuccess(internalServerStart, maxAttempts = 10) + println(s"FrontendHTTPServer started on port $port") + port + } + +} diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/AstTests.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/AstTests.scala index 86fbb6c1860c..15dbf7497572 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/AstTests.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/AstTests.scala @@ -1,9 +1,9 @@ package io.joern.x2cpg -import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewClosureBinding, NewIdentifier} +import flatgraph.SchemaViolationException +import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, Call, NewCall, NewClosureBinding, NewIdentifier} import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import overflowdb.SchemaViolationException class AstTests extends AnyWordSpec with Matchers { @@ -35,7 +35,7 @@ class AstTests extends AnyWordSpec with Matchers { copied.root match { case Some(root: NewCall) => root should not be Some(moo) - root.properties("NAME") shouldBe "moo" + root.properties(Call.PropertyNames.Name) shouldBe "moo" root.argumentIndex shouldBe 123 case _ => fail() } diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/SourceFilesTests.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/SourceFilesTests.scala index 2ebc2ffab906..cddfdd11d2e6 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/SourceFilesTests.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/SourceFilesTests.scala @@ -1,6 +1,6 @@ package io.joern.x2cpg -import better.files._ +import better.files.* import io.joern.x2cpg.utils.IgnoreInWindows import io.shiftleft.utils.ProjectRoot import org.scalatest.matchers.should.Matchers @@ -8,7 +8,7 @@ import org.scalatest.wordspec.AnyWordSpec import org.scalatest.Inside import java.nio.file.attribute.PosixFilePermissions -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.util.Try import java.io.FileNotFoundException @@ -53,6 +53,24 @@ class SourceFilesTests extends AnyWordSpec with Matchers with Inside { } + "do not throw an exception" when { + "one of the input files is a broken symlink" in { + File.usingTemporaryDirectory() { tmpDir => + (tmpDir / "a.c").touch() + val symlink = (tmpDir / "broken.c").symbolicLinkTo(File("does/not/exist.c")) + symlink.exists shouldBe false + symlink.isReadable shouldBe false + val ignored = (tmpDir / "ignored.c").touch() + val result = Try( + SourceFiles + .determine(tmpDir.canonicalPath, cSourceFileExtensions, ignoredFilesPath = Some(Seq(ignored.pathAsString))) + ) + result.isFailure shouldBe false + result.getOrElse(List.empty).size shouldBe 1 + } + } + } + "throw an exception" when { "the input file does not exist" in { diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/X2CpgTests.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/X2CpgTests.scala index 18ea217c43ef..7c9acf560a51 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/X2CpgTests.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/X2CpgTests.scala @@ -12,8 +12,7 @@ class X2CpgTests extends AnyWordSpec with Matchers { "create an empty in-memory CPG when no output path is given" in { val cpg = X2Cpg.newEmptyCpg(None) - cpg.graph.V.hasNext shouldBe false - cpg.graph.E.hasNext shouldBe false + cpg.graph.allNodes.hasNext shouldBe false cpg.close() } @@ -22,9 +21,9 @@ class X2CpgTests extends AnyWordSpec with Matchers { file.delete() file.exists shouldBe false val cpg = X2Cpg.newEmptyCpg(Some(file.path.toString)) + cpg.close() file.exists shouldBe true Files.size(file.path) should not be 0 - cpg.close() } "overwrite existing file to create empty CPG" in { @@ -32,11 +31,10 @@ class X2CpgTests extends AnyWordSpec with Matchers { file.exists shouldBe true Files.size(file.path) shouldBe 0 val cpg = X2Cpg.newEmptyCpg(Some(file.path.toString)) - cpg.graph.V.hasNext shouldBe false - cpg.graph.E.hasNext shouldBe false + cpg.graph.allNodes.hasNext shouldBe false + cpg.close() file.exists shouldBe true Files.size(file.path) should not be 0 - cpg.close() } } } diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorFrontierTests.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorFrontierTests.scala index 9b4d4e59c4d5..eadfe3f51682 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorFrontierTests.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorFrontierTests.scala @@ -1,48 +1,62 @@ package io.joern.x2cpg.passes -import io.shiftleft.OverflowDbTestInstance +import flatgraph.misc.TestUtils.* import io.joern.x2cpg.passes.controlflow.cfgdominator.{CfgAdapter, CfgDominator, CfgDominatorFrontier, DomTreeAdapter} +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes} +import io.shiftleft.codepropertygraph.generated.nodes.{NewUnknown, StoredNode} +import io.shiftleft.semanticcpg.language.* import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import overflowdb._ - -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* class CfgDominatorFrontierTests extends AnyWordSpec with Matchers { - private class TestCfgAdapter extends CfgAdapter[Node] { - override def successors(node: Node): IterableOnce[Node] = - node.out("CFG").asScala + private class TestCfgAdapter extends CfgAdapter[StoredNode] { + override def successors(node: StoredNode): Iterator[StoredNode] = + node.out("CFG").cast[StoredNode] - override def predecessors(node: Node): IterableOnce[Node] = - node.in("CFG").asScala + override def predecessors(node: StoredNode): Iterator[StoredNode] = + node.in("CFG").cast[StoredNode] } - private class TestDomTreeAdapter(immediateDominators: scala.collection.Map[Node, Node]) extends DomTreeAdapter[Node] { - override def immediateDominator(cfgNode: Node): Option[Node] = { + private class TestDomTreeAdapter(immediateDominators: scala.collection.Map[StoredNode, StoredNode]) + extends DomTreeAdapter[StoredNode] { + override def immediateDominator(cfgNode: StoredNode): Option[StoredNode] = { immediateDominators.get(cfgNode) } } "Cfg dominance frontier test" in { - val graph = OverflowDbTestInstance.create - - val v0 = graph + "UNKNOWN" - val v1 = graph + "UNKNOWN" - val v2 = graph + "UNKNOWN" - val v3 = graph + "UNKNOWN" - val v4 = graph + "UNKNOWN" - val v5 = graph + "UNKNOWN" - val v6 = graph + "UNKNOWN" - - v0 --- "CFG" --> v1 - v1 --- "CFG" --> v2 - v2 --- "CFG" --> v3 - v2 --- "CFG" --> v5 - v3 --- "CFG" --> v4 - v4 --- "CFG" --> v2 - v4 --- "CFG" --> v5 - v5 --- "CFG" --> v6 + val cpg = Cpg.empty + val graph = cpg.graph + + val v0 = graph.addNode(NewUnknown()) + val v1 = graph.addNode(NewUnknown()) + val v2 = graph.addNode(NewUnknown()) + val v3 = graph.addNode(NewUnknown()) + val v4 = graph.addNode(NewUnknown()) + val v5 = graph.addNode(NewUnknown()) + val v6 = graph.addNode(NewUnknown()) + + // TODO MP get arrow syntax back +// v0 --- "CFG" --> v1 +// v1 --- "CFG" --> v2 +// v2 --- "CFG" --> v3 +// v2 --- "CFG" --> v5 +// v3 --- "CFG" --> v4 +// v4 --- "CFG" --> v2 +// v4 --- "CFG" --> v5 +// v5 --- "CFG" --> v6 + graph.applyDiff { diffGraphBuilder => + diffGraphBuilder.addEdge(v0, v1, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v1, v2, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v2, v3, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v2, v5, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v3, v4, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v4, v2, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v4, v5, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v5, v6, EdgeTypes.CFG) + } val cfgAdapter = new TestCfgAdapter val cfgDominatorCalculator = new CfgDominator(cfgAdapter) @@ -50,7 +64,7 @@ class CfgDominatorFrontierTests extends AnyWordSpec with Matchers { val domTreeAdapter = new TestDomTreeAdapter(immediateDominators) val cfgDominatorFrontier = new CfgDominatorFrontier(cfgAdapter, domTreeAdapter) - val dominanceFrontier = cfgDominatorFrontier.calculate(graph.nodes.asScala.toList) + val dominanceFrontier = cfgDominatorFrontier.calculate(cpg.all) dominanceFrontier.get(v0) shouldBe None dominanceFrontier.get(v1) shouldBe None @@ -62,14 +76,20 @@ class CfgDominatorFrontierTests extends AnyWordSpec with Matchers { } "Cfg domiance frontier with dead code test" in { - val graph = OverflowDbTestInstance.create - - val v0 = graph + "UNKNOWN" - val v1 = graph + "UNKNOWN" // This node simulates dead code as it is not reachable from the entry v0. - val v2 = graph + "UNKNOWN" - - v0 --- "CFG" --> v2 - v1 --- "CFG" --> v2 + val cpg = Cpg.empty + val graph = cpg.graph + + val v0 = graph.addNode(NewUnknown()) + val v1 = graph.addNode(NewUnknown()) // This node simulates dead code as it is not reachable from the entry v0. + val v2 = graph.addNode(NewUnknown()) + + // TODO MP get arrow syntax back +// v0 --- "CFG" --> v2 +// v1 --- "CFG" --> v2 + graph.applyDiff { diffGraphBuilder => + diffGraphBuilder.addEdge(v0, v2, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v1, v2, EdgeTypes.CFG) + } val cfgAdapter = new TestCfgAdapter val cfgDominatorCalculator = new CfgDominator(cfgAdapter) @@ -77,7 +97,7 @@ class CfgDominatorFrontierTests extends AnyWordSpec with Matchers { val domTreeAdapter = new TestDomTreeAdapter(immediateDominators) val cfgDominatorFrontier = new CfgDominatorFrontier(cfgAdapter, domTreeAdapter) - val dominanceFrontier = cfgDominatorFrontier.calculate(graph.nodes.asScala.toList) + val dominanceFrontier = cfgDominatorFrontier.calculate(cpg.all) dominanceFrontier.get(v0) shouldBe None dominanceFrontier.apply(v1) shouldBe Set(v2) diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorPassTests.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorPassTests.scala index c8955335485a..1b5b615676ea 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorPassTests.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorPassTests.scala @@ -1,80 +1,93 @@ package io.joern.x2cpg.passes -import io.shiftleft.OverflowDbTestInstance -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes, NodeTypes} +import flatgraph.misc.TestUtils.* import io.joern.x2cpg.passes.controlflow.cfgdominator.CfgDominatorPass +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes, NodeTypes} +import io.shiftleft.codepropertygraph.generated.nodes.{NewMethod, NewMethodReturn, NewUnknown} +import io.shiftleft.semanticcpg.language.* import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import overflowdb._ -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* class CfgDominatorPassTests extends AnyWordSpec with Matchers { "Have correct DOMINATE/POST_DOMINATE edges after CfgDominatorPass run." in { - val graph = OverflowDbTestInstance.create - val cpg = new Cpg(graph) + val cpg = Cpg.empty + val graph = cpg.graph - val v0 = graph + NodeTypes.METHOD - val v1 = graph + NodeTypes.UNKNOWN - val v2 = graph + NodeTypes.UNKNOWN - val v3 = graph + NodeTypes.UNKNOWN - val v4 = graph + NodeTypes.UNKNOWN - val v5 = graph + NodeTypes.UNKNOWN - val v6 = graph + NodeTypes.METHOD_RETURN + val v0 = graph.addNode(NewMethod()) + val v1 = graph.addNode(NewUnknown()) + val v2 = graph.addNode(NewUnknown()) + val v3 = graph.addNode(NewUnknown()) + val v4 = graph.addNode(NewUnknown()) + val v5 = graph.addNode(NewUnknown()) + val v6 = graph.addNode(NewMethodReturn()) - v0 --- EdgeTypes.AST --> v6 + // TODO MP get arrow syntax back +// v0 --- EdgeTypes.AST --> v6 +// +// v0 --- EdgeTypes.CFG --> v1 +// v1 --- EdgeTypes.CFG --> v2 +// v2 --- EdgeTypes.CFG --> v3 +// v2 --- EdgeTypes.CFG --> v5 +// v3 --- EdgeTypes.CFG --> v4 +// v4 --- EdgeTypes.CFG --> v2 +// v4 --- EdgeTypes.CFG --> v5 +// v5 --- EdgeTypes.CFG --> v6 + graph.applyDiff { diffGraphBuilder => + diffGraphBuilder.addEdge(v0, v6, EdgeTypes.AST) - v0 --- EdgeTypes.CFG --> v1 - v1 --- EdgeTypes.CFG --> v2 - v2 --- EdgeTypes.CFG --> v3 - v2 --- EdgeTypes.CFG --> v5 - v3 --- EdgeTypes.CFG --> v4 - v4 --- EdgeTypes.CFG --> v2 - v4 --- EdgeTypes.CFG --> v5 - v5 --- EdgeTypes.CFG --> v6 + diffGraphBuilder.addEdge(v0, v1, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v1, v2, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v2, v3, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v2, v5, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v3, v4, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v4, v5, EdgeTypes.CFG) + diffGraphBuilder.addEdge(v5, v6, EdgeTypes.CFG) + } val dominatorTreePass = new CfgDominatorPass(cpg) dominatorTreePass.createAndApply() - val v0Dominates = v0.out(EdgeTypes.DOMINATE).asScala.toList + val v0Dominates = v0.out(EdgeTypes.DOMINATE).l v0Dominates.size shouldBe 1 v0Dominates.toSet shouldBe Set(v1) - val v1Dominates = v1.out(EdgeTypes.DOMINATE).asScala.toList + val v1Dominates = v1.out(EdgeTypes.DOMINATE).l v1Dominates.size shouldBe 1 v1Dominates.toSet shouldBe Set(v2) - val v2Dominates = v2.out(EdgeTypes.DOMINATE).asScala.toList + val v2Dominates = v2.out(EdgeTypes.DOMINATE).l v2Dominates.size shouldBe 2 v2Dominates.toSet shouldBe Set(v3, v5) - val v3Dominates = v3.out(EdgeTypes.DOMINATE).asScala.toList + val v3Dominates = v3.out(EdgeTypes.DOMINATE).l v3Dominates.size shouldBe 1 v3Dominates.toSet shouldBe Set(v4) - val v4Dominates = v4.out(EdgeTypes.DOMINATE).asScala.toList + val v4Dominates = v4.out(EdgeTypes.DOMINATE).l v4Dominates.size shouldBe 0 - val v5Dominates = v5.out(EdgeTypes.DOMINATE).asScala.toList + val v5Dominates = v5.out(EdgeTypes.DOMINATE).l v5Dominates.size shouldBe 1 v5Dominates.toSet shouldBe Set(v6) - val v6Dominates = v6.out(EdgeTypes.DOMINATE).asScala.toList + val v6Dominates = v6.out(EdgeTypes.DOMINATE).l v6Dominates.size shouldBe 0 - val v6PostDominates = v6.out(EdgeTypes.POST_DOMINATE).asScala.toList + val v6PostDominates = v6.out(EdgeTypes.POST_DOMINATE).l v6PostDominates.size shouldBe 1 v6PostDominates.toSet shouldBe Set(v5) - val v5PostDominates = v5.out(EdgeTypes.POST_DOMINATE).asScala.toList + val v5PostDominates = v5.out(EdgeTypes.POST_DOMINATE).l v5PostDominates.size shouldBe 2 v5PostDominates.toSet shouldBe Set(v2, v4) - val v4PostDominates = v4.out(EdgeTypes.POST_DOMINATE).asScala.toList + val v4PostDominates = v4.out(EdgeTypes.POST_DOMINATE).l v4PostDominates.size shouldBe 1 v4PostDominates.toSet shouldBe Set(v3) - val v3PostDominates = v3.out(EdgeTypes.POST_DOMINATE).asScala.toList + val v3PostDominates = v3.out(EdgeTypes.POST_DOMINATE).l v3PostDominates.size shouldBe 0 - val v2PostDominates = v2.out(EdgeTypes.POST_DOMINATE).asScala.toList + val v2PostDominates = v2.out(EdgeTypes.POST_DOMINATE).l v2PostDominates.size shouldBe 1 v2PostDominates.toSet shouldBe Set(v1) - val v1PostDominates = v1.out(EdgeTypes.POST_DOMINATE).asScala.toList + val v1PostDominates = v1.out(EdgeTypes.POST_DOMINATE).l v1PostDominates.size shouldBe 1 v1PostDominates.toSet shouldBe Set(v0) - val v0PostDominates = v0.out(EdgeTypes.POST_DOMINATE).asScala.toList + val v0PostDominates = v0.out(EdgeTypes.POST_DOMINATE).l v0PostDominates.size shouldBe 0 } } diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/ContainsEdgePassTest.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/ContainsEdgePassTest.scala index 329ccb09d32a..ba049a318622 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/ContainsEdgePassTest.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/ContainsEdgePassTest.scala @@ -1,14 +1,13 @@ package io.joern.x2cpg.passes -import io.shiftleft.OverflowDbTestInstance -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes, NodeTypes} +import flatgraph.misc.TestUtils.* import io.joern.x2cpg.passes.base.ContainsEdgePass +import io.shiftleft.codepropertygraph.generated.nodes.{NewCall, NewFile, NewMethod, NewTypeDecl} +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes} +import io.shiftleft.semanticcpg.language.* import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import overflowdb._ - -import scala.jdk.CollectionConverters._ class ContainsEdgePassTest extends AnyWordSpec with Matchers { @@ -16,26 +15,26 @@ class ContainsEdgePassTest extends AnyWordSpec with Matchers { "Files " can { "contain Methods" in Fixture { fixture => - fixture.methodVertex.in(EdgeTypes.CONTAINS).asScala.toList shouldBe List(fixture.fileVertex) + fixture.methodVertex.in(EdgeTypes.CONTAINS).l shouldBe List(fixture.fileVertex) } "contain Classes" in Fixture { fixture => - fixture.typeDeclVertex.in(EdgeTypes.CONTAINS).asScala.toList shouldBe List(fixture.fileVertex) + fixture.typeDeclVertex.in(EdgeTypes.CONTAINS).l shouldBe List(fixture.fileVertex) } } "Classes " can { "contain Methods" in Fixture { fixture => - fixture.typeMethodVertex.in(EdgeTypes.CONTAINS).asScala.toList shouldBe List(fixture.typeDeclVertex) + fixture.typeMethodVertex.in(EdgeTypes.CONTAINS).l shouldBe List(fixture.typeDeclVertex) } } "Methods " can { "contain Methods" in Fixture { fixture => - fixture.innerMethodVertex.in(EdgeTypes.CONTAINS).asScala.toList shouldBe List(fixture.methodVertex) + fixture.innerMethodVertex.in(EdgeTypes.CONTAINS).l shouldBe List(fixture.methodVertex) } "contain expressions" in Fixture { fixture => - fixture.expressionVertex.in(EdgeTypes.CONTAINS).asScala.toList shouldBe List(fixture.methodVertex) - fixture.innerExpressionVertex.in(EdgeTypes.CONTAINS).asScala.toList shouldBe List(fixture.innerMethodVertex) + fixture.expressionVertex.in(EdgeTypes.CONTAINS).l shouldBe List(fixture.methodVertex) + fixture.innerExpressionVertex.in(EdgeTypes.CONTAINS).l shouldBe List(fixture.innerMethodVertex) } } @@ -43,23 +42,34 @@ class ContainsEdgePassTest extends AnyWordSpec with Matchers { object ContainsEdgePassTest { private class Fixture { - private val graph = OverflowDbTestInstance.create + private val cpg = Cpg.empty + private val graph = cpg.graph - val fileVertex = graph + NodeTypes.FILE - val typeDeclVertex = graph + NodeTypes.TYPE_DECL - val typeMethodVertex = graph + NodeTypes.METHOD - val methodVertex = graph + NodeTypes.METHOD - val innerMethodVertex = graph + NodeTypes.METHOD - val expressionVertex = graph + NodeTypes.CALL - val innerExpressionVertex = graph + NodeTypes.CALL + val fileVertex = graph.addNode(NewFile()) + val typeDeclVertex = graph.addNode(NewTypeDecl()) + val typeMethodVertex = graph.addNode(NewMethod()) + val methodVertex = graph.addNode(NewMethod()) + val innerMethodVertex = graph.addNode(NewMethod()) + val expressionVertex = graph.addNode(NewCall()) + val innerExpressionVertex = graph.addNode(NewCall()) - fileVertex --- EdgeTypes.AST --> typeDeclVertex - typeDeclVertex --- EdgeTypes.AST --> typeMethodVertex + // TODO MP get arrow syntax back +// fileVertex --- EdgeTypes.AST --> typeDeclVertex +// typeDeclVertex --- EdgeTypes.AST --> typeMethodVertex +// +// fileVertex --- EdgeTypes.AST --> methodVertex +// methodVertex --- EdgeTypes.AST --> innerMethodVertex +// methodVertex --- EdgeTypes.AST --> expressionVertex +// innerMethodVertex --- EdgeTypes.AST --> innerExpressionVertex + graph.applyDiff { diffGraphBuilder => + diffGraphBuilder.addEdge(fileVertex, typeDeclVertex, EdgeTypes.AST) + diffGraphBuilder.addEdge(typeDeclVertex, typeMethodVertex, EdgeTypes.AST) - fileVertex --- EdgeTypes.AST --> methodVertex - methodVertex --- EdgeTypes.AST --> innerMethodVertex - methodVertex --- EdgeTypes.AST --> expressionVertex - innerMethodVertex --- EdgeTypes.AST --> innerExpressionVertex + diffGraphBuilder.addEdge(fileVertex, methodVertex, EdgeTypes.AST) + diffGraphBuilder.addEdge(methodVertex, innerMethodVertex, EdgeTypes.AST) + diffGraphBuilder.addEdge(methodVertex, expressionVertex, EdgeTypes.AST) + diffGraphBuilder.addEdge(innerMethodVertex, innerExpressionVertex, EdgeTypes.AST) + } val containsEdgeCalculator = new ContainsEdgePass(new Cpg(graph)) containsEdgeCalculator.createAndApply() diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MemberAccessLinkerTests.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MemberAccessLinkerTests.scala index cf9a848532ef..208e89d2850c 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MemberAccessLinkerTests.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MemberAccessLinkerTests.scala @@ -1,8 +1,8 @@ package io.joern.x2cpg.passes -import io.shiftleft.codepropertygraph.generated._ +import io.shiftleft.codepropertygraph.generated.* import io.shiftleft.codepropertygraph.generated.nodes.{NewCall, NewMember} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MethodDecoratorPassTests.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MethodDecoratorPassTests.scala index ac281490f0ee..6977bd1456cb 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MethodDecoratorPassTests.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MethodDecoratorPassTests.scala @@ -1,30 +1,32 @@ package io.joern.x2cpg.passes -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes.MethodParameterIn +import flatgraph.misc.TestUtils.* +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.joern.x2cpg.passes.base.MethodDecoratorPass import io.joern.x2cpg.testfixtures.EmptyGraphFixture import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import overflowdb._ class MethodDecoratorPassTests extends AnyWordSpec with Matchers { "MethodDecoratorTest" in EmptyGraphFixture { graph => - val method = graph + NodeTypes.METHOD - val parameterIn = graph - .+( - NodeTypes.METHOD_PARAMETER_IN, - Properties.Code -> "p1", - Properties.Order -> 1, - Properties.Name -> "p1", - Properties.EvaluationStrategy -> EvaluationStrategies.BY_REFERENCE, - Properties.TypeFullName -> "some.Type", - Properties.LineNumber -> 10 - ) - .asInstanceOf[MethodParameterIn] + val method = graph.addNode(NewMethod()) + val parameterIn = graph.addNode( + NewMethodParameterIn() + .code("p1") + .order(1) + .name("p1") + .evaluationStrategy(EvaluationStrategies.BY_REFERENCE) + .typeFullName("some.Type") + .lineNumber(10) + ) - method --- EdgeTypes.AST --> parameterIn + // TODO MP get arrow syntax back +// method --- EdgeTypes.AST --> parameterIn + graph.applyDiff { diffGraphBuilder => + diffGraphBuilder.addEdge(method, parameterIn, EdgeTypes.AST) + } val methodDecorator = new MethodDecoratorPass(new Cpg(graph)) methodDecorator.createAndApply() diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/NamespaceCreatorTests.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/NamespaceCreatorTests.scala index 63e0d7ec1e86..2f22f3c62b2e 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/NamespaceCreatorTests.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/NamespaceCreatorTests.scala @@ -1,22 +1,22 @@ package io.joern.x2cpg.passes -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.{NodeTypes, Properties} -import io.shiftleft.semanticcpg.language._ +import flatgraph.misc.TestUtils.addNode +import io.shiftleft.codepropertygraph.generated.{Cpg, NodeTypes} +import io.shiftleft.semanticcpg.language.* import io.joern.x2cpg.passes.base.NamespaceCreator import io.joern.x2cpg.testfixtures.EmptyGraphFixture +import io.shiftleft.codepropertygraph.generated.nodes.NewNamespaceBlock import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import overflowdb._ class NamespaceCreatorTests extends AnyWordSpec with Matchers { "NamespaceCreateor test " in EmptyGraphFixture { graph => val cpg = new Cpg(graph) - val block1 = graph + (NodeTypes.NAMESPACE_BLOCK, Properties.Name -> "namespace1") - val block2 = graph + (NodeTypes.NAMESPACE_BLOCK, Properties.Name -> "namespace1") - val block3 = graph + (NodeTypes.NAMESPACE_BLOCK, Properties.Name -> "namespace2") + val block1 = graph.addNode(NewNamespaceBlock().name("namespace1")) + val block2 = graph.addNode(NewNamespaceBlock().name("namespace1")) + val block3 = graph.addNode(NewNamespaceBlock().name("namespace2")) - val namespaceCreator = new NamespaceCreator(new Cpg(graph)) + val namespaceCreator = new NamespaceCreator(cpg) namespaceCreator.createAndApply() val namespaces = cpg.namespace.l diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/CfgTestFixture.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/CfgTestFixture.scala index 14c57a6fb97f..b8292db40bb4 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/CfgTestFixture.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/CfgTestFixture.scala @@ -4,7 +4,7 @@ import io.joern.x2cpg.passes.controlflow.CfgCreationPass import io.joern.x2cpg.passes.controlflow.cfgcreation.Cfg.CfgEdgeType import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{CfgNode, Method} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* abstract class CfgTestCpg extends TestCpg { override protected def applyPasses(): Unit = { @@ -30,7 +30,7 @@ class CfgTestFixture[T <: CfgTestCpg](testCpgFactory: () => T) extends Code2CpgF ExpectationInfo(pair._1, pair._2, pair._3) } - def expected(pairs: ExpectationInfo*)(implicit cpg: Cpg): Set[String] = { + def expected(pairs: ExpectationInfo*)(implicit cpg: Cpg): List[String] = { pairs.map { case ExpectationInfo(code, index, _) => cpg.method.ast.isCfgNode.toVector .collect { @@ -38,11 +38,11 @@ class CfgTestFixture[T <: CfgTestCpg](testCpgFactory: () => T) extends Code2CpgF } .lift(index) .getOrElse(fail(s"No node found for code = '$code' and index '$index'!")) - }.toSet + }.toList } // index is zero based and describes which node to take if multiple node match the code string. - def succOf(code: String, index: Int = 0)(implicit cpg: Cpg): Set[String] = { + def succOf(code: String, index: Int = 0)(implicit cpg: Cpg): List[String] = { cpg.method.ast.isCfgNode.toVector .collect { case node if matchCode(node, code) => node @@ -52,10 +52,10 @@ class CfgTestFixture[T <: CfgTestCpg](testCpgFactory: () => T) extends Code2CpgF ._cfgOut .cast[CfgNode] .code - .toSetImmutable + .toList } - def succOf(code: String, nodeType: String)(implicit cpg: Cpg): Set[String] = { + def succOf(code: String, nodeType: String)(implicit cpg: Cpg): List[String] = { cpg.method.ast.isCfgNode .label(nodeType) .toVector @@ -66,6 +66,6 @@ class CfgTestFixture[T <: CfgTestCpg](testCpgFactory: () => T) extends Code2CpgF ._cfgOut .cast[CfgNode] .code - .toSetImmutable + .toList } } diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/EmptyGraphFixture.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/EmptyGraphFixture.scala index 4a378c095580..36c6dcb2c481 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/EmptyGraphFixture.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/EmptyGraphFixture.scala @@ -1,12 +1,11 @@ package io.joern.x2cpg.testfixtures -import io.shiftleft.OverflowDbTestInstance -import overflowdb.Graph +import flatgraph.Graph +import io.shiftleft.codepropertygraph.generated.Cpg + +import scala.util.Using object EmptyGraphFixture { - def apply[T](fun: Graph => T): T = { - val graph = OverflowDbTestInstance.create - try fun(graph) - finally { graph.close() } - } + def apply[T](fun: Graph => T): T = + Using.resource(Cpg.empty.graph)(fun) } diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/TestCpg.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/TestCpg.scala index b70cddfb4c56..7dac8f7b7bee 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/TestCpg.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/TestCpg.scala @@ -1,9 +1,9 @@ package io.joern.x2cpg.testfixtures +import flatgraph.Graph import io.joern.x2cpg.X2CpgConfig import io.joern.x2cpg.utils.TestCodeWriter import io.shiftleft.codepropertygraph.generated.Cpg -import overflowdb.Graph import java.nio.file.{Files, Path} import java.util.Comparator @@ -11,7 +11,7 @@ import java.util.Comparator // Lazily populated test CPG which is created upon first access to the underlying graph. // The trait LanguageFrontend is mixed in and not property/field of this class in order // to allow the configuration of language frontend specific properties on the CPG object. -abstract class TestCpg extends Cpg() with LanguageFrontend with TestCodeWriter { +abstract class TestCpg extends Cpg(Cpg.empty.graph) with LanguageFrontend with TestCodeWriter { private var _graph = Option.empty[Graph] protected var _withPostProcessing = false diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/ExternalCommandTest.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/ExternalCommandTest.scala index 006af0bfff99..ef89b57fd2b3 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/ExternalCommandTest.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/ExternalCommandTest.scala @@ -4,15 +4,41 @@ import better.files.File import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec +import scala.util.Properties.isWin import scala.util.{Failure, Success} class ExternalCommandTest extends AnyWordSpec with Matchers { + private def cwd = File.currentWorkingDirectory.pathAsString + "ExternalCommand.run" should { "be able to run `ls` successfully" in { File.usingTemporaryDirectory("sample") { sourceDir => - val cmd = "ls " + sourceDir.pathAsString - ExternalCommand.run(cmd, sourceDir.pathAsString) should be a Symbol("success") + val cmd = Seq("ls", sourceDir.pathAsString) + ExternalCommand.run(cmd, sourceDir.pathAsString).toTry should be a Symbol("success") + } + } + + "report exit code and stdout/stderr for nonzero exit code" in { + ExternalCommand.run(Seq("ls", "/does/not/exist"), cwd).toTry match { + case result: Success[_] => + fail(s"expected failure, but got $result") + case Failure(exception) => + exception.getMessage should include("Process exited with code") // exit code `2` on linux, `1` on mac... + exception.getMessage should include("No such file or directory") // again, different errors on mac and linux + } + } + + "report error for io exception (e.g. for nonexisting command)" in { + ExternalCommand.run(Seq("/command/does/not/exist"), cwd).toTry match { + case result: Success[_] => + fail(s"expected failure, but got $result") + case Failure(exception) => + exception.getMessage should include("""Cannot run program "/command/does/not/exist"""") + if (isWin) + exception.getMessage should include("The system cannot find the file") + else + exception.getMessage should include("No such file or directory") } } } diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/KeyPoolTests.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/KeyPoolTests.scala new file mode 100644 index 000000000000..4815c4827332 --- /dev/null +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/KeyPoolTests.scala @@ -0,0 +1,74 @@ +package io.joern.x2cpg.utils + +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class KeyPoolTests extends AnyWordSpec with Matchers { + + "IntervalKeyPool" should { + "return [first, ..., last] and then raise" in { + val keyPool = new IntervalKeyPool(10, 19) + List.range(0, 10).map(_ => keyPool.next) shouldBe List.range(10, 20) + assertThrows[RuntimeException] { keyPool.next } + assertThrows[RuntimeException] { keyPool.next } + } + + "allow splitting into multiple pools" in { + val keyPool = new IntervalKeyPool(1, 1000) + val pools = keyPool.split(11).toList + assertThrows[IllegalStateException] { keyPool.next } + pools.size shouldBe 11 + // Pools should all have the same size + pools + .map { x => + (x.last - x.first) + } + .distinct + .size shouldBe 1 + // Pools should be pairwise disjoint + val keySets = pools.map { x => + (x.first to x.last).toSet + } + keySets.combinations(2).foreach { + case List(x: Set[Long], y: Set[Long]) => + x.intersect(y).isEmpty shouldBe true + case _ => + fail() + } + } + + "return empty iterator when asked to create 0 partitions" in { + val keyPool = new IntervalKeyPool(1, 1000) + keyPool.split(0).hasNext shouldBe false + } + + } + + "SequenceKeyPool" should { + "return elements of sequence one by one and then raise" in { + val seq = List[Long](1, 2, 3) + val keyPool = new SequenceKeyPool(seq) + List.range(0, 3).map(_ => keyPool.next) shouldBe seq + assertThrows[RuntimeException] { keyPool.next } + assertThrows[RuntimeException] { keyPool.next } + } + } + + "KeyPoolCreator" should { + "split into n pools and honor minimum value" in { + val minValue = 10 + val pools = KeyPoolCreator.obtain(3, minValue) + pools.size shouldBe 3 + pools match { + case List(pool1, pool2, pool3) => + pool1.first shouldBe minValue + pool1.last should be < pool2.first + pool2.last should be < pool3.first + pool3.last shouldBe Long.MaxValue - 1 + case _ => fail() + } + } + + } + +} diff --git a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/dependency/DependencyResolverTests.scala b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/dependency/DependencyResolverTests.scala index 58bde517e1bd..0a8c79b14727 100644 --- a/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/dependency/DependencyResolverTests.scala +++ b/joern-cli/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/dependency/DependencyResolverTests.scala @@ -31,7 +31,7 @@ class DependencyResolverTests extends AnyWordSpec with Matchers { "test maven dependency resolution" ignore { // check that `mvn` is available - otherwise test will fail with only some logged warnings... withClue("`mvn` must be installed in order for this test to work...") { - ExternalCommand.run("mvn --version", ".").get.exists(_.contains("Apache Maven")) shouldBe true + ExternalCommand.run(Seq("mvn", "--version"), ".").successOption.exists(_.contains("Apache Maven")) shouldBe true } @nowarn // otherwise scalac warns that this might be an interpolated expression diff --git a/joern-cli/src/main/resources/scripts/trigger-error.sc b/joern-cli/src/main/resources/scripts/trigger-error.sc index bbc2174fb90d..1943cbae22d0 100644 --- a/joern-cli/src/main/resources/scripts/trigger-error.sc +++ b/joern-cli/src/main/resources/scripts/trigger-error.sc @@ -1 +1,2 @@ -assert(true == false, "trigger an error for testing purposes") +import scala.util.control.NoStackTrace +throw new Exception("triggering an error for testing purposes") with NoStackTrace diff --git a/joern-cli/src/main/scala/io/joern/joerncli/CpgBasedTool.scala b/joern-cli/src/main/scala/io/joern/joerncli/CpgBasedTool.scala index d40ae3550443..274920f8a7d4 100644 --- a/joern-cli/src/main/scala/io/joern/joerncli/CpgBasedTool.scala +++ b/joern-cli/src/main/scala/io/joern/joerncli/CpgBasedTool.scala @@ -4,22 +4,23 @@ import better.files.File import io.joern.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlowOptions} import io.joern.dataflowengineoss.semanticsloader.Semantics import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.cpgloading.CpgLoaderConfig import io.shiftleft.semanticcpg.layers.LayerCreatorContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.codepropertygraph.cpgloading.CpgLoader object CpgBasedTool { + def loadFromFile(filename: String): Cpg = + CpgLoader.load(filename) + /** Load code property graph from overflowDB * * @param filename * name of the file that stores the CPG */ - def loadFromOdb(filename: String): Cpg = { - val odbConfig = overflowdb.Config.withDefaults().withStorageLocation(filename) - val config = CpgLoaderConfig().withOverflowConfig(odbConfig).doNotCreateIndexesOnLoad - io.shiftleft.codepropertygraph.cpgloading.CpgLoader.loadFromOverflowDb(config) - } + @deprecated("use `loadFromFile` instead", "joern v3") + def loadFromOdb(filename: String): Cpg = + loadFromFile(filename) /** Add the data flow layer to the CPG if it does not exist yet. */ diff --git a/joern-cli/src/main/scala/io/joern/joerncli/DefaultOverlays.scala b/joern-cli/src/main/scala/io/joern/joerncli/DefaultOverlays.scala index 4cc406f104d5..3adc182a4eb3 100644 --- a/joern-cli/src/main/scala/io/joern/joerncli/DefaultOverlays.scala +++ b/joern-cli/src/main/scala/io/joern/joerncli/DefaultOverlays.scala @@ -3,7 +3,7 @@ package io.joern.joerncli import io.joern.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlowOptions} import io.joern.x2cpg.X2Cpg.applyDefaultOverlays import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.layers._ +import io.shiftleft.semanticcpg.layers.* object DefaultOverlays { @@ -16,7 +16,7 @@ object DefaultOverlays { * the filename of the cpg */ def create(storeFilename: String, maxNumberOfDefinitions: Int = defaultMaxNumberOfDefinitions): Cpg = { - val cpg = CpgBasedTool.loadFromOdb(storeFilename) + val cpg = CpgBasedTool.loadFromFile(storeFilename) applyDefaultOverlays(cpg) val context = new LayerCreatorContext(cpg) val options = new OssDataFlowOptions(maxNumberOfDefinitions) diff --git a/joern-cli/src/main/scala/io/joern/joerncli/JoernExport.scala b/joern-cli/src/main/scala/io/joern/joerncli/JoernExport.scala index 6b0b78612602..7bea6d43f6aa 100644 --- a/joern-cli/src/main/scala/io/joern/joerncli/JoernExport.scala +++ b/joern-cli/src/main/scala/io/joern/joerncli/JoernExport.scala @@ -2,21 +2,21 @@ package io.joern.joerncli import better.files.Dsl.* import better.files.File +import flatgraph.{Accessors, Edge, GNode} +import flatgraph.formats.ExportResult +import flatgraph.formats.dot.DotExporter +import flatgraph.formats.graphml.GraphMLExporter +import flatgraph.formats.graphson.GraphSONExporter +import flatgraph.formats.neo4jcsv.Neo4jCsvExporter import io.joern.dataflowengineoss.DefaultSemantics import io.joern.dataflowengineoss.layers.dataflows.* -import io.joern.dataflowengineoss.semanticsloader.Semantics +import io.joern.dataflowengineoss.semanticsloader.{NoSemantics, Semantics} import io.joern.joerncli.CpgBasedTool.exitIfInvalid import io.joern.x2cpg.layers.* import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.NodeTypes -import io.shiftleft.semanticcpg.language.{toAstNodeMethods, toNodeTypeStarters} +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.* -import overflowdb.formats.ExportResult -import overflowdb.formats.dot.DotExporter -import overflowdb.formats.graphml.GraphMLExporter -import overflowdb.formats.graphson.GraphSONExporter -import overflowdb.formats.neo4jcsv.Neo4jCsvExporter -import overflowdb.{Edge, Node} import java.nio.file.{Path, Paths} import scala.collection.mutable @@ -64,7 +64,7 @@ object JoernExport { exitIfInvalid(outDir, config.cpgFileName) mkdir(File(outDir)) - Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg => + Using.resource(CpgBasedTool.loadFromFile(config.cpgFileName)) { cpg => exportCpg(cpg, config.repr, config.format, Paths.get(outDir).toAbsolutePath) } } @@ -96,7 +96,7 @@ object JoernExport { def exportCpg(cpg: Cpg, representation: Representation.Value, format: Format.Value, outDir: Path): Unit = { implicit val semantics: Semantics = DefaultSemantics() - if (semantics.elements.isEmpty) { + if (semantics == NoSemantics) { System.err.println("Warning: semantics are empty.") } @@ -105,15 +105,15 @@ object JoernExport { format match { case Format.Dot if representation == Representation.All || representation == Representation.Cpg => - exportWithOdbFormat(cpg, representation, outDir, DotExporter) + exportWithFlatgraphFormat(cpg, representation, outDir, DotExporter) case Format.Dot => exportDot(representation, outDir, context) case Format.Neo4jCsv => - exportWithOdbFormat(cpg, representation, outDir, Neo4jCsvExporter) + exportWithFlatgraphFormat(cpg, representation, outDir, Neo4jCsvExporter) case Format.Graphml => - exportWithOdbFormat(cpg, representation, outDir, GraphMLExporter) + exportWithFlatgraphFormat(cpg, representation, outDir, GraphMLExporter) case Format.Graphson => - exportWithOdbFormat(cpg, representation, outDir, GraphSONExporter) + exportWithFlatgraphFormat(cpg, representation, outDir, GraphSONExporter) case other => throw new NotImplementedError(s"repr=$representation not yet supported for format=$format") } @@ -133,11 +133,11 @@ object JoernExport { } } - private def exportWithOdbFormat( + private def exportWithFlatgraphFormat( cpg: Cpg, repr: Representation.Value, outDir: Path, - exporter: overflowdb.formats.Exporter + exporter: flatgraph.formats.Exporter ): Unit = { val ExportResult(nodeCount, edgeCount, _, additionalInfo) = repr match { case Representation.All => @@ -154,7 +154,7 @@ object JoernExport { windowsFilenameDeduplicationHelper ) val outFileName = outDir.resolve(relativeFilename) - exporter.runExport(nodes, subGraph.edges, outFileName) + exporter.runExport(cpg.graph.schema, nodes, subGraph.edges, outFileName) } .reduce(plus) } else { @@ -220,12 +220,12 @@ object JoernExport { private def emptyExportResult = ExportResult(0, 0, Seq.empty, Option("Empty CPG")) - case class MethodSubGraph(methodName: String, methodFilename: String, nodes: Set[Node]) { + case class MethodSubGraph(methodName: String, methodFilename: String, nodes: Set[GNode]) { def edges: Set[Edge] = { for { node <- nodes - edge <- node.bothE.asScala - if nodes.contains(edge.inNode) && nodes.contains(edge.outNode) + edge <- Accessors.getEdgesOut(node) + if nodes.contains(edge.dst) } yield edge } } diff --git a/joern-cli/src/main/scala/io/joern/joerncli/JoernFlow.scala b/joern-cli/src/main/scala/io/joern/joerncli/JoernFlow.scala index 07cf60a4ef28..211bcd1ec0b4 100644 --- a/joern-cli/src/main/scala/io/joern/joerncli/JoernFlow.scala +++ b/joern-cli/src/main/scala/io/joern/joerncli/JoernFlow.scala @@ -28,7 +28,7 @@ object JoernFlow { } debugOut("Loading graph... ") - val cpg = CpgBasedTool.loadFromOdb(config.cpgFileName) + val cpg = CpgBasedTool.loadFromFile(config.cpgFileName) debugOut("[DONE]\n") implicit val resolver: ICallResolver = NoResolve diff --git a/joern-cli/src/main/scala/io/joern/joerncli/JoernParse.scala b/joern-cli/src/main/scala/io/joern/joerncli/JoernParse.scala index 34e2ff7a2889..1bfed24fc7f6 100644 --- a/joern-cli/src/main/scala/io/joern/joerncli/JoernParse.scala +++ b/joern-cli/src/main/scala/io/joern/joerncli/JoernParse.scala @@ -4,15 +4,14 @@ import better.files.File import io.joern.console.cpgcreation.{CpgGenerator, cpgGeneratorForLanguage, guessLanguage} import io.joern.console.{FrontendConfig, InstallConfig} import io.joern.joerncli.CpgBasedTool.newCpgCreatedString +import io.joern.x2cpg.frontendspecific.FrontendArgsDelimitor import io.shiftleft.codepropertygraph.generated.Languages import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.util.{Failure, Success, Try} object JoernParse { - // Special string used to separate joern-parse opts from frontend-specific opts - val ArgsDelimitor = "--frontend-args" val DefaultCpgOutFile = "cpg.bin" var generator: CpgGenerator = scala.compiletime.uninitialized @@ -64,7 +63,7 @@ object JoernParse { note("Misc") help("help").text("display this help message") - note(s"Args specified after the $ArgsDelimitor separator will be passed to the front-end verbatim") + note(s"Args specified after the $FrontendArgsDelimitor separator will be passed to the front-end verbatim") } private def run(args: Array[String]): Try[String] = { diff --git a/joern-cli/src/main/scala/io/joern/joerncli/JoernScan.scala b/joern-cli/src/main/scala/io/joern/joerncli/JoernScan.scala index c5728589c15a..ae08472490c4 100644 --- a/joern-cli/src/main/scala/io/joern/joerncli/JoernScan.scala +++ b/joern-cli/src/main/scala/io/joern/joerncli/JoernScan.scala @@ -4,18 +4,19 @@ import better.files.* import io.joern.console.scan.{ScanPass, outputFindings} import io.joern.console.{BridgeBase, DefaultArgumentProvider, Query, QueryDatabase} import io.joern.dataflowengineoss.queryengine.{EngineConfig, EngineContext} -import io.joern.dataflowengineoss.semanticsloader.Semantics +import io.joern.dataflowengineoss.semanticsloader.{Semantics, NoSemantics} import io.joern.joerncli.JoernScan.getQueriesFromQueryDb import io.joern.joerncli.Scan.{allTag, defaultTag} import io.joern.joerncli.console.ReplBridge import io.shiftleft.codepropertygraph.generated.Languages import io.shiftleft.semanticcpg.language.{DefaultNodeExtensionFinder, NodeExtensionFinder} import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} -import java.io.PrintStream + import org.json4s.native.Serialization import org.json4s.{Formats, NoTypeHints} + import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* object JoernScanConfig { val defaultDbVersion: String = "latest" @@ -128,7 +129,7 @@ object JoernScan extends BridgeBase { } private def dumpQueriesAsJson(outFileName: String): Unit = { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val formats: AnyRef & Formats = Serialization.formats(NoTypeHints) val queryDb = new QueryDatabase(new JoernDefaultArgumentProvider(0)) better.files @@ -179,7 +180,7 @@ object JoernScan extends BridgeBase { } private def queryNames(): List[String] = { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) getQueriesFromQueryDb(new JoernDefaultArgumentProvider(0)).map(_.name) } @@ -237,12 +238,10 @@ object JoernScan extends BridgeBase { } } - override protected def predefLines = ReplBridge.predefLines - override protected def promptStr = ReplBridge.promptStr - - override protected def greeting = ReplBridge.greeting - - override protected def onExitCode = ReplBridge.onExitCode + override protected def runBeforeCode = ReplBridge.runBeforeCode + override protected def promptStr = ReplBridge.promptStr + override protected def greeting = ReplBridge.greeting + override protected def onExitCode = ReplBridge.onExitCode } object Scan { @@ -274,7 +273,7 @@ class Scan(options: ScanOptions)(implicit engineContext: EngineContext) extends println("No queries matched current filter selection (total number of queries: `" + allQueries.length + "`)") return } - runPass(new ScanPass(context.cpg, queriesAfterFilter), context) + ScanPass(context.cpg, queriesAfterFilter).createAndApply() outputFindings(context.cpg) } diff --git a/joern-cli/src/main/scala/io/joern/joerncli/JoernSlice.scala b/joern-cli/src/main/scala/io/joern/joerncli/JoernSlice.scala index c16f09eb428a..453d370eed5b 100644 --- a/joern-cli/src/main/scala/io/joern/joerncli/JoernSlice.scala +++ b/joern-cli/src/main/scala/io/joern/joerncli/JoernSlice.scala @@ -122,7 +122,7 @@ object JoernSlice { } else { config.inputPath.pathAsString } - Using.resource(CpgBasedTool.loadFromOdb(inputCpgPath)) { cpg => + Using.resource(CpgBasedTool.loadFromFile(inputCpgPath)) { cpg => checkAndApplyOverlays(cpg) // Slice the CPG (config match { diff --git a/joern-cli/src/main/scala/io/joern/joerncli/JoernVectors.scala b/joern-cli/src/main/scala/io/joern/joerncli/JoernVectors.scala index 57ba6673bcc0..d32bceb28f33 100644 --- a/joern-cli/src/main/scala/io/joern/joerncli/JoernVectors.scala +++ b/joern-cli/src/main/scala/io/joern/joerncli/JoernVectors.scala @@ -15,11 +15,10 @@ import scala.util.hashing.MurmurHash3 class BagOfPropertiesForNodes extends EmbeddingGenerator[AstNode, (String, String)] { override def structureToString(pair: (String, String)): String = pair._1 + ":" + pair._2 - override def extractObjects(cpg: Cpg): Iterator[AstNode] = cpg.graph.V.collect { case x: AstNode => x } + override def extractObjects(cpg: Cpg): Iterator[AstNode] = cpg.astNode override def enumerateSubStructures(obj: AstNode): List[(String, String)] = { val relevantFieldTypes = Set(PropertyNames.NAME, PropertyNames.FULL_NAME, PropertyNames.CODE) - val relevantFields = obj - .propertiesMap() + val relevantFields = obj.propertiesMap .entrySet() .asScala .toList @@ -136,7 +135,7 @@ object JoernVectors { def main(args: Array[String]) = { parseConfig(args).foreach { config => exitIfInvalid(config.outDir, config.cpgFileName) - Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg => + Using.resource(CpgBasedTool.loadFromFile(config.cpgFileName)) { cpg => val generator = new BagOfPropertiesForNodes() val embedding = generator.embed(cpg) println("{") @@ -150,8 +149,8 @@ object JoernVectors { traversalToJson(embedding.vectors, generator.vectorToString) println(",\"edges\":") traversalToJson( - cpg.graph.edges().map { x => - Map("src" -> x.outNode().id(), "dst" -> x.inNode().id(), "label" -> x.label()) + cpg.graph.allEdges.map { edge => + Map("src" -> edge.src.id, "dst" -> edge.dst.id, "label" -> edge.label) }, generator.defaultToString ) diff --git a/joern-cli/src/main/scala/io/joern/joerncli/console/JoernConsole.scala b/joern-cli/src/main/scala/io/joern/joerncli/console/JoernConsole.scala index 93b7ce1a7bc5..607e9a3074aa 100644 --- a/joern-cli/src/main/scala/io/joern/joerncli/console/JoernConsole.scala +++ b/joern-cli/src/main/scala/io/joern/joerncli/console/JoernConsole.scala @@ -1,6 +1,6 @@ package io.joern.joerncli.console -import better.files._ +import better.files.* import io.joern.console.defaultAvailableWidthProvider import io.joern.console.workspacehandling.{ProjectFile, WorkspaceLoader} import io.joern.console.{Console, ConsoleConfig, InstallConfig} diff --git a/joern-cli/src/main/scala/io/joern/joerncli/console/JoernProject.scala b/joern-cli/src/main/scala/io/joern/joerncli/console/JoernProject.scala index df1bd9fb86e9..ecfaa2f7dd96 100644 --- a/joern-cli/src/main/scala/io/joern/joerncli/console/JoernProject.scala +++ b/joern-cli/src/main/scala/io/joern/joerncli/console/JoernProject.scala @@ -2,7 +2,7 @@ package io.joern.joerncli.console import io.joern.console.workspacehandling.{Project, ProjectFile} import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics +import io.joern.dataflowengineoss.semanticsloader.NoSemantics import io.shiftleft.codepropertygraph.generated.Cpg import java.nio.file.Path @@ -11,5 +11,5 @@ class JoernProject( projectFile: ProjectFile, path: Path, cpg: Option[Cpg] = None, - var context: EngineContext = EngineContext(Semantics.empty) + var context: EngineContext = EngineContext(NoSemantics) ) extends Project(projectFile, path, cpg) {} diff --git a/joern-cli/src/main/scala/io/joern/joerncli/console/Predefined.scala b/joern-cli/src/main/scala/io/joern/joerncli/console/Predefined.scala deleted file mode 100644 index f3107c8ff907..000000000000 --- a/joern-cli/src/main/scala/io/joern/joerncli/console/Predefined.scala +++ /dev/null @@ -1,34 +0,0 @@ -package io.joern.joerncli.console - -import io.joern.console.{Help, Run} - -object Predefined { - - val shared: Seq[String] = - Seq( - "import _root_.io.joern.console._", - "import _root_.io.joern.joerncli.console.JoernConsole._", - "import _root_.io.shiftleft.codepropertygraph.Cpg.docSearchPackages", - "import _root_.io.shiftleft.codepropertygraph.generated.Cpg", - "import _root_.io.shiftleft.codepropertygraph.cpgloading._", - "import _root_.io.shiftleft.codepropertygraph.generated._", - "import _root_.io.shiftleft.codepropertygraph.generated.nodes._", - "import _root_.io.shiftleft.codepropertygraph.generated.edges._", - "import _root_.io.joern.dataflowengineoss.language._", - "import _root_.io.shiftleft.semanticcpg.language._", - "import overflowdb._", - "import overflowdb.traversal.{`package` => _, help => _, _}", - "import scala.jdk.CollectionConverters._", - "implicit val resolver: ICallResolver = NoResolve", - "implicit val finder: NodeExtensionFinder = DefaultNodeExtensionFinder" - ) - - val forInteractiveShell: Seq[String] = { - shared ++ - Seq("import _root_.io.joern.joerncli.console.Joern._") ++ - Run.codeForRunCommand().linesIterator ++ - Help.codeForHelpCommand(classOf[io.joern.joerncli.console.JoernConsole]).linesIterator ++ - Seq("ossDataFlowOptions = opts.ossdataflow") - } - -} diff --git a/joern-cli/src/main/scala/io/joern/joerncli/console/ReplBridge.scala b/joern-cli/src/main/scala/io/joern/joerncli/console/ReplBridge.scala index 8b774e5e96eb..bf3fabff8c52 100644 --- a/joern-cli/src/main/scala/io/joern/joerncli/console/ReplBridge.scala +++ b/joern-cli/src/main/scala/io/joern/joerncli/console/ReplBridge.scala @@ -1,7 +1,6 @@ package io.joern.joerncli.console import io.joern.console.BridgeBase -import java.io.PrintStream object ReplBridge extends BridgeBase { @@ -13,8 +12,7 @@ object ReplBridge extends BridgeBase { /** Code that is executed when starting the shell */ - override def predefLines = - Predefined.forInteractiveShell + override def runBeforeCode = RunBeforeCode.forInteractiveShell override def greeting = JoernConsole.banner() diff --git a/joern-cli/src/main/scala/io/joern/joerncli/console/RunBeforeCode.scala b/joern-cli/src/main/scala/io/joern/joerncli/console/RunBeforeCode.scala new file mode 100644 index 000000000000..0d315c64325e --- /dev/null +++ b/joern-cli/src/main/scala/io/joern/joerncli/console/RunBeforeCode.scala @@ -0,0 +1,31 @@ +package io.joern.joerncli.console + +import io.joern.console.{Help, Run} + +object RunBeforeCode { + + val shared: Seq[String] = + Seq( + "import _root_.io.joern.console.*", + "import _root_.io.joern.joerncli.console.JoernConsole.*", + "import _root_.io.shiftleft.codepropertygraph.cpgloading.*", + "import _root_.io.shiftleft.codepropertygraph.generated.{help => _, _}", + "import _root_.io.shiftleft.codepropertygraph.generated.nodes.*", + "import _root_.io.joern.dataflowengineoss.language.*", + "import _root_.io.shiftleft.semanticcpg.language.*", + "import scala.jdk.CollectionConverters.*", + "import _root_.io.shiftleft.semanticcpg.sarif.SarifConfig", + "implicit val resolver: ICallResolver = NoResolve", + "implicit val finder: NodeExtensionFinder = DefaultNodeExtensionFinder", + "implicit val sarifConfig: SarifConfig = SarifConfig(semanticVersion = Option(version))" + ) + + val forInteractiveShell: Seq[String] = { + shared ++ + Seq("import _root_.io.joern.joerncli.console.Joern.*") ++ + Run.codeForRunCommand().linesIterator ++ + Help.codeForHelpCommand(classOf[io.joern.joerncli.console.JoernConsole]).linesIterator ++ + Seq("ossDataFlowOptions = opts.ossdataflow") + } + +} diff --git a/joern-cli/src/test/resources/additional-import.sc b/joern-cli/src/test/resources/additional-import.sc new file mode 100644 index 000000000000..a782a1d75a16 --- /dev/null +++ b/joern-cli/src/test/resources/additional-import.sc @@ -0,0 +1,2 @@ +def sayHello(to: String) = + s"hello, $to" \ No newline at end of file diff --git a/joern-cli/src/test/scala/io/joern/joerncli/GenerationTests.scala b/joern-cli/src/test/scala/io/joern/joerncli/GenerationTests.scala index 6cf0a8e14e88..3f6ae67c4ba8 100644 --- a/joern-cli/src/test/scala/io/joern/joerncli/GenerationTests.scala +++ b/joern-cli/src/test/scala/io/joern/joerncli/GenerationTests.scala @@ -1,7 +1,7 @@ package io.joern.joerncli import better.files.File -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/joern-cli/src/test/scala/io/joern/joerncli/RunScriptTests.scala b/joern-cli/src/test/scala/io/joern/joerncli/RunScriptTests.scala index 8b1b4b36d1cf..47e568446456 100644 --- a/joern-cli/src/test/scala/io/joern/joerncli/RunScriptTests.scala +++ b/joern-cli/src/test/scala/io/joern/joerncli/RunScriptTests.scala @@ -1,5 +1,7 @@ package io.joern.joerncli +import better.files._ +import java.nio.file.Paths import io.joern.console.Config import io.joern.joerncli.console.ReplBridge import io.shiftleft.utils.ProjectRoot @@ -27,6 +29,114 @@ class RunScriptTests extends AnyWordSpec with Matchers { } } + "execute a simple script" in new Fixture { + def test(scriptFile: File, outputFile: File) = { + val escScriptPath = outputFile.pathAsString.replace("\\", "\\\\") + scriptFile.write(s""" + |val fw = new java.io.FileWriter("$escScriptPath", true) + |fw.write("michael was here") + |fw.close() + """.stripMargin) + + ReplBridge.main(Array("--script", scriptFile.pathAsString)) + + withClue(s"$outputFile content: ") { + outputFile.lines.head shouldBe "michael was here" + } + } + } + + "pass parameters to script" in new Fixture { + def test(scriptFile: File, outputFile: File) = { + scriptFile.write(s""" + |@main def foo(outFile: String, magicNumber: Int) = { + | val fw = new java.io.FileWriter(outFile, true) + | fw.write(magicNumber.toString) + | fw.close() + |} + """.stripMargin) + + ReplBridge.main( + Array( + "--script", + scriptFile.pathAsString, + "--param", + s"outFile=${outputFile.pathAsString}", + "--param", + "magicNumber=42" + ) + ) + + withClue(s"$outputFile content: ") { + outputFile.lines.head shouldBe "42" + } + } + } + + "script with multiple @main methods" in new Fixture { + def test(scriptFile: File, outputFile: File) = { + val escScriptPath = outputFile.pathAsString.replace("\\", "\\\\") + + scriptFile.write(s""" + |@main def foo() = { + | val fw = new java.io.FileWriter("$escScriptPath", true) + | fw.write("foo was called") + | fw.close() + |} + |@main def bar() = { + | val fw = new java.io.FileWriter("$escScriptPath", true) + | fw.write("bar was called") + | fw.close() + |} + """.stripMargin) + + ReplBridge.main(Array("--script", scriptFile.pathAsString, "--command", "bar")) + + withClue(s"$outputFile content: ") { + outputFile.lines.head shouldBe "bar was called" + } + } + } + + "use additional import script: //> using file directive" in new Fixture { + def test(scriptFile: File, outputFile: File) = { + val escScriptPath = outputFile.pathAsString.replace("\\", "\\\\") + val additionalImportFile = Paths.get("joern-cli/src/test/resources/additional-import.sc").toAbsolutePath + + scriptFile.write(s""" + |//> using file $additionalImportFile + |val fw = new java.io.FileWriter("$escScriptPath", true) + |fw.write(sayHello("michael")) //function defined in additionalImportFile + |fw.close() + """.stripMargin) + + ReplBridge.main(Array("--script", scriptFile.pathAsString)) + + withClue(s"$outputFile content: ") { + outputFile.lines.head shouldBe "hello, michael" + } + } + } + + "use additional import script: --import parameter" in new Fixture { + def test(scriptFile: File, outputFile: File) = { + val escScriptPath = outputFile.pathAsString.replace("\\", "\\\\") + val additionalImportFile = Paths.get("joern-cli/src/test/resources/additional-import.sc").toAbsolutePath + + scriptFile.write(s""" + |val fw = new java.io.FileWriter("$escScriptPath", true) + |fw.write(sayHello("michael")) //function defined in additionalImportFile + |fw.close() + """.stripMargin) + + ReplBridge.main(Array("--script", scriptFile.pathAsString, "--import", additionalImportFile.toString)) + + withClue(s"$outputFile content: ") { + outputFile.lines.head shouldBe "hello, michael" + } + } + } + "should return Failure if" when { "script doesn't exist" in { val result = ReplBridge.runScript(Config(scriptFile = Some(scriptsRoot.resolve("does-not-exist.sc")))) @@ -53,4 +163,12 @@ object RunScriptTests { ) .get } + + trait Fixture { + def test(scriptFile: File, outputFile: File): Unit + for { + scriptFile <- File.temporaryFile() + outputFile <- File.temporaryFile() + } test(scriptFile, outputFile) + } } diff --git a/joern-cli/src/universal/schema-extender/build.sbt b/joern-cli/src/universal/schema-extender/build.sbt index 7632c73b0f74..af3209b2a74a 100644 --- a/joern-cli/src/universal/schema-extender/build.sbt +++ b/joern-cli/src/universal/schema-extender/build.sbt @@ -1,10 +1,10 @@ name := "schema-extender" -ThisBuild / scalaVersion := "3.4.1" +ThisBuild / scalaVersion := "3.5.2" val cpgVersion = IO.read(file("cpg-version")) -val generateDomainClasses = taskKey[Seq[File]]("generate overflowdb domain classes for our schema") +val generateDomainClasses = taskKey[Seq[File]]("generate domain classes for our schema") val joernInstallPath = settingKey[String]("path to joern installation, e.g. `/home/username/bin/joern/joern-cli` or `../../joern/joern-cli`") @@ -33,9 +33,9 @@ ThisBuild / libraryDependencies ++= Seq( lazy val schema = project .in(file("schema")) .settings(generateDomainClasses := { - val outputRoot = target.value / "odb-codegen" + val outputRoot = target.value / "fg-codegen" FileUtils.deleteRecursively(outputRoot) - val invoked = (Compile / runMain).toTask(s" CpgExtCodegen schema/target/odb-codegen").value + val invoked = (Compile / runMain).toTask(s" CpgExtCodegen schema/target/fg-codegen").value FileUtils.listFilesRecursively(outputRoot) }) diff --git a/joern-cli/src/universal/schema-extender/project/FileUtils.scala b/joern-cli/src/universal/schema-extender/project/FileUtils.scala index dc61c44afee2..4bd8aaae7980 100644 --- a/joern-cli/src/universal/schema-extender/project/FileUtils.scala +++ b/joern-cli/src/universal/schema-extender/project/FileUtils.scala @@ -1,6 +1,6 @@ import java.io.File import java.nio.file.Files -import scala.collection.JavaConverters._ +import scala.collection.JavaConverters.* object FileUtils { diff --git a/joern-cli/src/universal/schema-extender/schema/src/main/scala/CpgExtCodegen.scala b/joern-cli/src/universal/schema-extender/schema/src/main/scala/CpgExtCodegen.scala index a94cfedd5857..5bdd11cdac44 100644 --- a/joern-cli/src/universal/schema-extender/schema/src/main/scala/CpgExtCodegen.scala +++ b/joern-cli/src/universal/schema-extender/schema/src/main/scala/CpgExtCodegen.scala @@ -1,14 +1,13 @@ -import io.shiftleft.codepropertygraph.schema._ -import overflowdb.codegen.CodeGen -import overflowdb.schema.SchemaBuilder -import overflowdb.schema.Property.ValueType - -import java.io.File +import io.shiftleft.codepropertygraph.schema.* +import flatgraph.codegen.DomainClassesGenerator +import flatgraph.schema.SchemaBuilder +import flatgraph.schema.Property.ValueType +import java.nio.file.Paths object CpgExtCodegen { def main(args: Array[String]): Unit = { val outputDir = args.headOption - .map(new File(_)) + .map(Paths.get(_)) .getOrElse(throw new AssertionError("please pass outputDir as first parameter")) val builder = new SchemaBuilder(domainShortName = "Cpg", basePackage = "io.shiftleft.codepropertygraph.generated") @@ -24,6 +23,6 @@ object CpgExtCodegen { cpgSchema.fs.file.addProperties(exampleProperty) // END extensions for this build - new CodeGen(builder.build).run(outputDir) + new DomainClassesGenerator(builder.build).run(outputDir) } } diff --git a/joern-cli/src/universal/schema-extender/test.sh b/joern-cli/src/universal/schema-extender/test.sh index 309ab80b710a..4c3fba8ac3b6 100755 --- a/joern-cli/src/universal/schema-extender/test.sh +++ b/joern-cli/src/universal/schema-extender/test.sh @@ -9,6 +9,6 @@ set -x #verbose on # we should now be able to use our new `EXAMPLE_NODE` node mkdir -p scripts echo 'assert(nodes.ExampleNode.Label == "EXAMPLE_NODE") -assert(nodes.ExampleNode.PropertyNames.all.contains("EXAMPLE_PROPERTY"))' > scripts/SchemaExtenderTest.sc +assert(nodes.ExampleNode.PropertyNames.ExampleProperty == "EXAMPLE_PROPERTY")' > scripts/SchemaExtenderTest.sc ./joern --script scripts/SchemaExtenderTest.sc diff --git a/joern-install.sh b/joern-install.sh index 487177c9cd0d..e3734fec69a9 100755 --- a/joern-install.sh +++ b/joern-install.sh @@ -190,7 +190,6 @@ else sudo ln -sf "$JOERN_INSTALL_DIR"/joern-cli/joern-export "$JOERN_LINK_DIR" || true sudo ln -sf "$JOERN_INSTALL_DIR"/joern-cli/joern-flow "$JOERN_LINK_DIR" || true sudo ln -sf "$JOERN_INSTALL_DIR"/joern-cli/joern-scan "$JOERN_LINK_DIR" || true - sudo ln -sf "$JOERN_INSTALL_DIR"/joern-cli/joern-stats "$JOERN_LINK_DIR" || true sudo ln -sf "$JOERN_INSTALL_DIR"/joern-cli/joern-slice "$JOERN_LINK_DIR" || true fi fi diff --git a/macros/build.sbt b/macros/build.sbt index d54e400c5e88..51f63dbd8026 100644 --- a/macros/build.sbt +++ b/macros/build.sbt @@ -3,8 +3,9 @@ name := "macros" dependsOn(Projects.semanticcpg % Test) libraryDependencies ++= Seq( - "io.shiftleft" %% "codepropertygraph" % Versions.cpg, - "org.scalatest" %% "scalatest" % Versions.scalatest % Test + "io.shiftleft" %% "codepropertygraph" % Versions.cpg, + "net.oneandone.reflections8" % "reflections8" % "0.11.7", + "org.scalatest" %% "scalatest" % Versions.scalatest % Test ) enablePlugins(JavaAppPackaging) diff --git a/macros/src/main/scala/io/joern/console/QueryDatabase.scala b/macros/src/main/scala/io/joern/console/QueryDatabase.scala index 31ca139b5c13..54933ef15478 100644 --- a/macros/src/main/scala/io/joern/console/QueryDatabase.scala +++ b/macros/src/main/scala/io/joern/console/QueryDatabase.scala @@ -5,7 +5,7 @@ import org.reflections8.util.{ClasspathHelper, ConfigurationBuilder} import java.lang.reflect.{Method, Parameter} import scala.annotation.unused -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* trait QueryBundle diff --git a/macros/src/test/scala/io/joern/console/QueryDatabaseTests.scala b/macros/src/test/scala/io/joern/console/QueryDatabaseTests.scala index 6f026090c225..3dd84bca6ed0 100644 --- a/macros/src/test/scala/io/joern/console/QueryDatabaseTests.scala +++ b/macros/src/test/scala/io/joern/console/QueryDatabaseTests.scala @@ -1,7 +1,7 @@ package io.joern.console import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.matchers.should import org.scalatest.wordspec.AnyWordSpec diff --git a/macros/src/test/scala/io/joern/macros/QueryMacroTests.scala b/macros/src/test/scala/io/joern/macros/QueryMacroTests.scala index 5476f9eaa48b..254a7cd3ae9d 100644 --- a/macros/src/test/scala/io/joern/macros/QueryMacroTests.scala +++ b/macros/src/test/scala/io/joern/macros/QueryMacroTests.scala @@ -4,8 +4,8 @@ import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec import io.joern.macros.QueryMacros.withStrRep -import io.joern.console._ -import io.shiftleft.semanticcpg.language._ +import io.joern.console.* +import io.shiftleft.semanticcpg.language.* class QueryMacroTests extends AnyWordSpec with Matchers { "Query macros" should { diff --git a/project/Versions.scala b/project/Versions.scala index 9e75b78f4c4f..b7bdd352a8de 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -4,37 +4,41 @@ object Versions { // Dont upgrade antlr to 4.10 or above since those versions require java 11 or higher which // causes problems upstreams. val antlr = "4.7.2" - val cask = "0.9.2" + val cask = "0.9.5" // 0.9.5 is actually the latest release, not 0.10.2 ¯\_(ツ)_/¯ - check the cask git commits... val catsCore = "2.12.0" val catsEffect = "3.5.4" val cfr = "0.152" val commonsCompress = "1.26.2" + val commonsExec = "1.4.0" val commonsIo = "2.16.0" val commonsLang = "3.14.0" val commonsText = "1.12.0" - val eclipseCdt = "8.4.0.202401242025" - val eclipseCore = "3.20.100" - val eclipseText = "3.14.0" - val ghidra = "11.0_PUBLIC_20231222-2" + val eclipseCdt = "8.5.0.202410191453+3" + val eclipseCore = "3.22.0" + val eclipseText = "3.14.200" + val ghidra = "11.2.1_PUBLIC_20241105-7" val gradleTooling = "8.3" val jacksonDatabind = "2.17.0" - val javaParser = "3.25.9" + val javaParser = "3.26.2" + val jlhttp = "3.1" + val jRuby = "9.4.9.0" val json4s = "4.0.7" val lombok = "1.18.32" val mavenArcheologist = "0.0.10" + val phpParser = "4.15.10" val pPrint = "0.8.1" val reflection = "0.10.2" val requests = "0.8.0" val scalaParallel = "1.0.4" val scalaParserCombinators = "2.4.0" - val scalaReplPP = "0.1.87" + val scalaReplPP = "0.3.9" val scalatest = "3.2.18" val scopt = "4.1.0" val semverParser = "0.0.6" - val soot = "4.5.0" + val soot = "4.6.0" val slf4j = "2.0.7" val log4j = "2.20.0" - val upickle = "3.3.1" + val upickle = "4.0.2" val zeroTurnaround = "1.17" // Shared with `projects/meta-build.sbt`, which needs to be updated there directly @@ -44,6 +48,7 @@ object Versions { val typeSafeConfig = "1.4.3" val versionSort = "1.0.11" val zip4j = "2.11.5" + val asm = "9.7.1" private def parseVersion(key: String): String = { val versionRegexp = s""".*val $key[ ]+=[ ]?"(.*?)"""".r diff --git a/project/build.properties b/project/build.properties index 081fdbbc7625..73df629ac1a7 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.10.0 +sbt.version=1.10.7 diff --git a/project/plugins.sbt b/project/plugins.sbt index f06ef61201fe..2f11a41bfef7 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,5 +1,5 @@ addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.8.3") -addSbtPlugin("com.github.sbt" % "sbt-native-packager" % "1.10.0") +addSbtPlugin("com.github.sbt" % "sbt-native-packager" % "1.10.4") addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2") -addSbtPlugin("io.shiftleft" % "sbt-ci-release-early" % "2.0.19") -addSbtPlugin("com.github.sbt" % "sbt-dynver" % "5.0.1") +addSbtPlugin("io.shiftleft" % "sbt-ci-release-early" % "2.0.48") +addSbtPlugin("com.github.sbt" % "sbt-dynver" % "5.1.0") diff --git a/querydb/src/main/scala/io/joern/dumpq/Main.scala b/querydb/src/main/scala/io/joern/dumpq/Main.scala index b2b012e5da07..32f6c8cf3098 100644 --- a/querydb/src/main/scala/io/joern/dumpq/Main.scala +++ b/querydb/src/main/scala/io/joern/dumpq/Main.scala @@ -2,7 +2,7 @@ package io.joern.dumpq import io.joern.console.{DefaultArgumentProvider, QueryDatabase} import io.joern.dataflowengineoss.queryengine.{EngineConfig, EngineContext} -import io.joern.dataflowengineoss.semanticsloader.Semantics +import io.joern.dataflowengineoss.semanticsloader.NoSemantics import org.json4s.{Formats, NoTypeHints} import org.json4s.native.Serialization @@ -13,7 +13,7 @@ object Main { } def dumpQueries(): Unit = { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val formats: AnyRef & Formats = Serialization.formats(NoTypeHints) val queryDb = new QueryDatabase(new JoernDefaultArgumentProvider(0)) diff --git a/querydb/src/main/scala/io/joern/scanners/Crew.scala b/querydb/src/main/scala/io/joern/scanners/Crew.scala index efb1662ddab1..058538313322 100644 --- a/querydb/src/main/scala/io/joern/scanners/Crew.scala +++ b/querydb/src/main/scala/io/joern/scanners/Crew.scala @@ -8,5 +8,6 @@ object Crew { val claudiu = "@ursachec" val malte = "@maltek" val dave = "@DavidBakerEffendi" + val SJ1iu = "@piggyctf" } diff --git a/querydb/src/main/scala/io/joern/scanners/android/ArbitraryFileWrites.scala b/querydb/src/main/scala/io/joern/scanners/android/ArbitraryFileWrites.scala index 97e78f177457..a897527bd642 100644 --- a/querydb/src/main/scala/io/joern/scanners/android/ArbitraryFileWrites.scala +++ b/querydb/src/main/scala/io/joern/scanners/android/ArbitraryFileWrites.scala @@ -1,15 +1,15 @@ package io.joern.scanners.android -import io.joern.scanners._ -import io.joern.console._ +import io.joern.scanners.* +import io.joern.console.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.semanticsloader.NoSemantics +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* object ArbitraryFileWrites extends QueryBundle { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val resolver: ICallResolver = NoResolve // todo: improve accuracy, might lead to high number of false positives diff --git a/querydb/src/main/scala/io/joern/scanners/android/ExternalStorage.scala b/querydb/src/main/scala/io/joern/scanners/android/ExternalStorage.scala index 7bbb3199d27f..857f341291c8 100644 --- a/querydb/src/main/scala/io/joern/scanners/android/ExternalStorage.scala +++ b/querydb/src/main/scala/io/joern/scanners/android/ExternalStorage.scala @@ -3,13 +3,13 @@ package io.joern.scanners.android import io.joern.console.* import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics +import io.joern.dataflowengineoss.semanticsloader.NoSemantics import io.joern.macros.QueryMacros.* import io.joern.scanners.* import io.shiftleft.semanticcpg.language.* object ExternalStorage extends QueryBundle { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val resolver: ICallResolver = NoResolve // TODO: improve matching around external storage permissions diff --git a/querydb/src/main/scala/io/joern/scanners/android/Intents.scala b/querydb/src/main/scala/io/joern/scanners/android/Intents.scala index 2b8de93a286e..20f506048be1 100644 --- a/querydb/src/main/scala/io/joern/scanners/android/Intents.scala +++ b/querydb/src/main/scala/io/joern/scanners/android/Intents.scala @@ -1,15 +1,15 @@ package io.joern.scanners.android -import io.joern.scanners._ -import io.joern.console._ +import io.joern.scanners.* +import io.joern.console.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.semanticsloader.NoSemantics +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* object Intents extends QueryBundle { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val resolver: ICallResolver = NoResolve @q diff --git a/querydb/src/main/scala/io/joern/scanners/android/JavaScriptInterface.scala b/querydb/src/main/scala/io/joern/scanners/android/JavaScriptInterface.scala index 41834d9b4ae0..0c01df9a7c4d 100644 --- a/querydb/src/main/scala/io/joern/scanners/android/JavaScriptInterface.scala +++ b/querydb/src/main/scala/io/joern/scanners/android/JavaScriptInterface.scala @@ -3,13 +3,13 @@ package io.joern.scanners.android import io.joern.console.* import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics +import io.joern.dataflowengineoss.semanticsloader.NoSemantics import io.joern.macros.QueryMacros.* import io.joern.scanners.* import io.shiftleft.semanticcpg.language.* object JavaScriptInterface extends QueryBundle { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val resolver: ICallResolver = NoResolve // TODO: take into account network_security_config diff --git a/querydb/src/main/scala/io/joern/scanners/android/RootDetection.scala b/querydb/src/main/scala/io/joern/scanners/android/RootDetection.scala index 4586fcccb094..d586a0426670 100644 --- a/querydb/src/main/scala/io/joern/scanners/android/RootDetection.scala +++ b/querydb/src/main/scala/io/joern/scanners/android/RootDetection.scala @@ -1,15 +1,15 @@ package io.joern.scanners.android -import io.joern.scanners._ -import io.joern.console._ +import io.joern.scanners.* +import io.joern.console.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ +import io.joern.dataflowengineoss.semanticsloader.NoSemantics +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* object RootDetection extends QueryBundle { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val resolver: ICallResolver = NoResolve @q diff --git a/querydb/src/main/scala/io/joern/scanners/android/UnprotectedAppParts.scala b/querydb/src/main/scala/io/joern/scanners/android/UnprotectedAppParts.scala index c4b1be1f536f..b2d9e501050d 100644 --- a/querydb/src/main/scala/io/joern/scanners/android/UnprotectedAppParts.scala +++ b/querydb/src/main/scala/io/joern/scanners/android/UnprotectedAppParts.scala @@ -3,13 +3,13 @@ package io.joern.scanners.android import io.joern.console.* import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics +import io.joern.dataflowengineoss.semanticsloader.NoSemantics import io.joern.macros.QueryMacros.* import io.joern.scanners.* import io.shiftleft.semanticcpg.language.* object UnprotectedAppParts extends QueryBundle { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val resolver: ICallResolver = NoResolve @q diff --git a/querydb/src/main/scala/io/joern/scanners/android/UnsafeReflection.scala b/querydb/src/main/scala/io/joern/scanners/android/UnsafeReflection.scala index a9c3ec5c3d5b..e4f76b6f74b2 100644 --- a/querydb/src/main/scala/io/joern/scanners/android/UnsafeReflection.scala +++ b/querydb/src/main/scala/io/joern/scanners/android/UnsafeReflection.scala @@ -1,14 +1,14 @@ package io.joern.scanners.android -import io.joern.scanners._ -import io.joern.console._ +import io.joern.scanners.* +import io.joern.console.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language._ +import io.joern.dataflowengineoss.semanticsloader.NoSemantics +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.* object UnsafeReflection extends QueryBundle { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val resolver: ICallResolver = NoResolve // todo: support `build.gradle.kts` diff --git a/querydb/src/main/scala/io/joern/scanners/c/CopyLoops.scala b/querydb/src/main/scala/io/joern/scanners/c/CopyLoops.scala index 7b00fb2a9c5a..4495ece033a7 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/CopyLoops.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/CopyLoops.scala @@ -1,9 +1,9 @@ package io.joern.scanners.c -import io.joern.scanners._ -import io.joern.console._ -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language._ +import io.joern.scanners.* +import io.joern.console.* +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.* object CopyLoops extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/c/CredentialDrop.scala b/querydb/src/main/scala/io/joern/scanners/c/CredentialDrop.scala index 966cc3d79022..c8d9cb07c5db 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/CredentialDrop.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/CredentialDrop.scala @@ -1,9 +1,9 @@ package io.joern.scanners.c -import io.joern.scanners._ -import io.joern.console._ -import io.shiftleft.semanticcpg.language._ -import io.joern.macros.QueryMacros._ +import io.joern.scanners.* +import io.joern.console.* +import io.shiftleft.semanticcpg.language.* +import io.joern.macros.QueryMacros.* object CredentialDrop extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/c/DangerousFunctions.scala b/querydb/src/main/scala/io/joern/scanners/c/DangerousFunctions.scala index 9ef4d463a9ee..cd54e9a76625 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/DangerousFunctions.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/DangerousFunctions.scala @@ -1,9 +1,9 @@ package io.joern.scanners.c -import io.joern.scanners._ -import io.joern.console._ -import io.shiftleft.semanticcpg.language._ -import io.joern.macros.QueryMacros._ +import io.joern.scanners.* +import io.joern.console.* +import io.shiftleft.semanticcpg.language.* +import io.joern.macros.QueryMacros.* object DangerousFunctions extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/c/HeapBasedOverflow.scala b/querydb/src/main/scala/io/joern/scanners/c/HeapBasedOverflow.scala index e7e58ca2933c..14207f3bebd7 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/HeapBasedOverflow.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/HeapBasedOverflow.scala @@ -1,16 +1,16 @@ package io.joern.scanners.c -import io.joern.scanners._ +import io.joern.scanners.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ -import io.joern.console._ -import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.joern.macros.QueryMacros._ +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* +import io.joern.console.* +import io.joern.dataflowengineoss.semanticsloader.NoSemantics +import io.joern.macros.QueryMacros.* object HeapBasedOverflow extends QueryBundle { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val resolver: ICallResolver = NoResolve /** Find calls to malloc where the first argument contains an arithmetic expression, the allocated buffer flows into diff --git a/querydb/src/main/scala/io/joern/scanners/c/IntegerTruncations.scala b/querydb/src/main/scala/io/joern/scanners/c/IntegerTruncations.scala index 3f5bfe035b2b..405bd68b7ebb 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/IntegerTruncations.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/IntegerTruncations.scala @@ -1,9 +1,9 @@ package io.joern.scanners.c -import io.joern.scanners._ -import io.shiftleft.semanticcpg.language._ -import io.joern.console._ -import io.joern.macros.QueryMacros._ +import io.joern.scanners.* +import io.shiftleft.semanticcpg.language.* +import io.joern.console.* +import io.joern.macros.QueryMacros.* object IntegerTruncations extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/c/Metrics.scala b/querydb/src/main/scala/io/joern/scanners/c/Metrics.scala index 7cc521612c18..96acd023b041 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/Metrics.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/Metrics.scala @@ -1,9 +1,9 @@ package io.joern.scanners.c -import io.joern.scanners._ -import io.joern.console._ -import io.shiftleft.semanticcpg.language._ -import io.joern.macros.QueryMacros._ +import io.joern.scanners.* +import io.joern.console.* +import io.shiftleft.semanticcpg.language.* +import io.joern.macros.QueryMacros.* object Metrics extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/c/MissingLengthCheck.scala b/querydb/src/main/scala/io/joern/scanners/c/MissingLengthCheck.scala index ed6b7d17dd92..f82855e5866f 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/MissingLengthCheck.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/MissingLengthCheck.scala @@ -1,14 +1,14 @@ package io.joern.scanners.c import io.shiftleft.semanticcpg.language.{ICallResolver, NoResolve} -import io.joern.scanners._ -import io.joern.console._ +import io.joern.scanners.* +import io.joern.console.* import io.shiftleft.codepropertygraph.generated.nodes import io.joern.dataflowengineoss.queryengine.EngineContext -import io.shiftleft.semanticcpg.language._ -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language.operatorextension._ -import QueryLangExtensions._ +import io.shiftleft.semanticcpg.language.* +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.operatorextension.* +import QueryLangExtensions.* object MissingLengthCheck extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/c/NullTermination.scala b/querydb/src/main/scala/io/joern/scanners/c/NullTermination.scala index 99bbc63a5b3d..36da5781b0ec 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/NullTermination.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/NullTermination.scala @@ -1,16 +1,16 @@ package io.joern.scanners.c import io.joern.scanners.{Crew, QueryTags} -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ -import io.joern.console._ +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* +import io.joern.console.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.joern.macros.QueryMacros._ +import io.joern.dataflowengineoss.semanticsloader.NoSemantics +import io.joern.macros.QueryMacros.* object NullTermination extends QueryBundle { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val resolver: ICallResolver = NoResolve @q diff --git a/querydb/src/main/scala/io/joern/scanners/c/RetvalChecks.scala b/querydb/src/main/scala/io/joern/scanners/c/RetvalChecks.scala index ca0c9e3eaf7d..479ea0c060f8 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/RetvalChecks.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/RetvalChecks.scala @@ -1,10 +1,10 @@ package io.joern.scanners.c import io.joern.scanners.{Crew, QueryTags} -import io.joern.console._ -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language._ -import QueryLangExtensions._ +import io.joern.console.* +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.* +import QueryLangExtensions.* object RetvalChecks extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/c/SignedLeftShift.scala b/querydb/src/main/scala/io/joern/scanners/c/SignedLeftShift.scala index 232da036f815..7730e7236342 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/SignedLeftShift.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/SignedLeftShift.scala @@ -1,10 +1,10 @@ package io.joern.scanners.c -import io.joern.scanners._ +import io.joern.scanners.* import io.shiftleft.codepropertygraph.generated.Operators -import io.joern.console._ -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language._ +import io.joern.console.* +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.* object SignedLeftShift extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/c/SocketApi.scala b/querydb/src/main/scala/io/joern/scanners/c/SocketApi.scala index 9fe1d3901450..db5cc5171738 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/SocketApi.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/SocketApi.scala @@ -1,11 +1,11 @@ package io.joern.scanners.c import io.joern.scanners.{Crew, QueryTags} -import io.joern.console._ +import io.joern.console.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language._ -import QueryLangExtensions._ +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.* +import QueryLangExtensions.* object SocketApi extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/c/UseAfterFree.scala b/querydb/src/main/scala/io/joern/scanners/c/UseAfterFree.scala index 034b279affbd..75e137b70057 100644 --- a/querydb/src/main/scala/io/joern/scanners/c/UseAfterFree.scala +++ b/querydb/src/main/scala/io/joern/scanners/c/UseAfterFree.scala @@ -236,7 +236,7 @@ object UseAfterFree extends QueryBundle { | if (cond) { | free(x); | if (cond2) - | return x; // not post-dominated by free call + | return x; // doesn't post-dominate the free call | x = NULL; | } | return x; diff --git a/querydb/src/main/scala/io/joern/scanners/ghidra/DangerousFunctions.scala b/querydb/src/main/scala/io/joern/scanners/ghidra/DangerousFunctions.scala index 16216093a627..0ccebc7be7ef 100644 --- a/querydb/src/main/scala/io/joern/scanners/ghidra/DangerousFunctions.scala +++ b/querydb/src/main/scala/io/joern/scanners/ghidra/DangerousFunctions.scala @@ -1,9 +1,9 @@ package io.joern.scanners.ghidra -import io.joern.scanners._ -import io.joern.console._ -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language._ +import io.joern.scanners.* +import io.joern.console.* +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.* object DangerousFunctions extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/ghidra/UserInputIntoDangerousFunctions.scala b/querydb/src/main/scala/io/joern/scanners/ghidra/UserInputIntoDangerousFunctions.scala index 21310c543181..3c47d454a90c 100644 --- a/querydb/src/main/scala/io/joern/scanners/ghidra/UserInputIntoDangerousFunctions.scala +++ b/querydb/src/main/scala/io/joern/scanners/ghidra/UserInputIntoDangerousFunctions.scala @@ -1,10 +1,10 @@ package io.joern.scanners.ghidra -import io.joern.scanners._ -import io.joern.console._ -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language._ -import io.joern.dataflowengineoss.language._ +import io.joern.scanners.* +import io.joern.console.* +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.* +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.queryengine.EngineContext object UserInputIntoDangerousFunctions extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/java/CrossSiteScripting.scala b/querydb/src/main/scala/io/joern/scanners/java/CrossSiteScripting.scala index e7265c367e12..66cfe2fe7f28 100644 --- a/querydb/src/main/scala/io/joern/scanners/java/CrossSiteScripting.scala +++ b/querydb/src/main/scala/io/joern/scanners/java/CrossSiteScripting.scala @@ -1,10 +1,10 @@ package io.joern.scanners.java -import io.joern.scanners._ -import io.shiftleft.semanticcpg.language._ -import io.joern.console._ -import io.joern.macros.QueryMacros._ -import io.joern.dataflowengineoss.language._ +import io.joern.scanners.* +import io.shiftleft.semanticcpg.language.* +import io.joern.console.* +import io.joern.macros.QueryMacros.* +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.queryengine.EngineContext object CrossSiteScripting extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/java/CryptographyMisuse.scala b/querydb/src/main/scala/io/joern/scanners/java/CryptographyMisuse.scala index 1e43586657f0..5b96efcabfbb 100644 --- a/querydb/src/main/scala/io/joern/scanners/java/CryptographyMisuse.scala +++ b/querydb/src/main/scala/io/joern/scanners/java/CryptographyMisuse.scala @@ -1,10 +1,10 @@ package io.joern.scanners.java -import io.joern.scanners._ -import io.shiftleft.semanticcpg.language._ -import io.joern.console._ -import io.joern.macros.QueryMacros._ -import io.joern.dataflowengineoss.language._ +import io.joern.scanners.* +import io.shiftleft.semanticcpg.language.* +import io.joern.console.* +import io.joern.macros.QueryMacros.* +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.queryengine.EngineContext /** @see diff --git a/querydb/src/main/scala/io/joern/scanners/java/DangerousFunctions.scala b/querydb/src/main/scala/io/joern/scanners/java/DangerousFunctions.scala index cdfabc11b03f..ccee98ccbf3a 100644 --- a/querydb/src/main/scala/io/joern/scanners/java/DangerousFunctions.scala +++ b/querydb/src/main/scala/io/joern/scanners/java/DangerousFunctions.scala @@ -1,9 +1,9 @@ package io.joern.scanners.java -import io.joern.scanners._ -import io.shiftleft.semanticcpg.language._ -import io.joern.console._ -import io.joern.macros.QueryMacros._ +import io.joern.scanners.* +import io.shiftleft.semanticcpg.language.* +import io.joern.console.* +import io.joern.macros.QueryMacros.* object DangerousFunctions extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/java/SQLInjection.scala b/querydb/src/main/scala/io/joern/scanners/java/SQLInjection.scala index 67aebe735261..fa987cfd722e 100644 --- a/querydb/src/main/scala/io/joern/scanners/java/SQLInjection.scala +++ b/querydb/src/main/scala/io/joern/scanners/java/SQLInjection.scala @@ -1,10 +1,10 @@ package io.joern.scanners.java -import io.joern.scanners._ -import io.shiftleft.semanticcpg.language._ -import io.joern.console._ -import io.joern.macros.QueryMacros._ -import io.joern.dataflowengineoss.language._ +import io.joern.scanners.* +import io.shiftleft.semanticcpg.language.* +import io.joern.console.* +import io.joern.macros.QueryMacros.* +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.queryengine.EngineContext // The queries are tied to springframework diff --git a/querydb/src/main/scala/io/joern/scanners/java/SpringExpressionLanguageInjection.scala b/querydb/src/main/scala/io/joern/scanners/java/SpringExpressionLanguageInjection.scala new file mode 100644 index 000000000000..2749f1434f34 --- /dev/null +++ b/querydb/src/main/scala/io/joern/scanners/java/SpringExpressionLanguageInjection.scala @@ -0,0 +1,88 @@ +package io.joern.scanners.java; + +import io.joern.scanners.* +import io.shiftleft.semanticcpg.language.* +import io.joern.console.* +import io.joern.macros.QueryMacros.* +import io.joern.dataflowengineoss.language.* +import io.joern.dataflowengineoss.queryengine.EngineContext + +object SpringExpressionLanguageInjection extends QueryBundle { + + implicit val resolver: ICallResolver = NoResolve + + @q + def SpelInject()(implicit context: EngineContext): Query = + Query.make( + name = "Spring-Expression-Language-Injection", + author = Crew.SJ1iu, + title = + "Spring-Expression-Language-Injection: The value is taken from user input and passed to ExpressionParser!!", + description = """ + | In a SpEL injection, if user-controlled input is directly parsed and evaluated as a SpEL expression without validation, attackers can execute arbitrary expressions. + |""".stripMargin, + score = 8, + withStrRep({ cpg => + + def source = + cpg.parameter.where(_.annotation.name("RequestParam")).where(_.name("expression")) + + def sink = + cpg.call.name("parseExpression").argument.order(2).l + + sink.reachableBy(source).l + + }), + tags = List(QueryTags.badfn, QueryTags.default), + multiFileCodeExamples = MultiFileCodeExamples( + positive = List( + List( + CodeSnippet( + """ + |import org.springframework.expression.ExpressionParser; + |import org.springframework.expression.spel.standard.SpelExpressionParser; + |import org.springframework.web.bind.annotation.GetMapping; + |import org.springframework.web.bind.annotation.RequestParam; + |import org.springframework.web.bind.annotation.RestController; + |@RestController + |public class SpelInjectionController { + |private final ExpressionParser parser = new SpelExpressionParser(); + + |@GetMapping("/evaluate") + |public String evaluateExpression(@RequestParam String expression) { + |// This line is vulnerable to SpEL injection as it directly evaluates user input + |Object result = parser.parseExpression(expression).getValue(); + |return "Evaluation result: " + result; + |} + |} + |""".stripMargin, + "Positive.kt" + ) + ) + ), + negative = List( + List( + CodeSnippet( + """ + |import org.springframework.expression.ExpressionParser; + |import org.springframework.expression.spel.standard.SpelExpressionParser; + |import org.springframework.web.bind.annotation.GetMapping; + |import org.springframework.web.bind.annotation.RequestParam; + |import org.springframework.web.bind.annotation.RestController; + |@RestController + |public class SpelInjectionController { + |private final ExpressionParser parser = new SpelExpressionParser(); + + |@GetMapping("/evaluate") + |public String evaluateExpression(@RequestParam String expression) { + |return "NOT VULNERABLE"; + |} + |} + |""".stripMargin, + "Negative.kt" + ) + ) + ) + ) + ) +} diff --git a/querydb/src/main/scala/io/joern/scanners/kotlin/NetworkCommunication.scala b/querydb/src/main/scala/io/joern/scanners/kotlin/NetworkCommunication.scala index 2f31e6c2800d..9669bbfbc6ef 100644 --- a/querydb/src/main/scala/io/joern/scanners/kotlin/NetworkCommunication.scala +++ b/querydb/src/main/scala/io/joern/scanners/kotlin/NetworkCommunication.scala @@ -1,16 +1,16 @@ package io.joern.scanners.kotlin -import io.joern.scanners._ -import io.joern.console._ +import io.joern.console.* +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.joern.dataflowengineoss.language._ -import io.joern.macros.QueryMacros._ +import io.joern.dataflowengineoss.semanticsloader.NoSemantics +import io.joern.macros.QueryMacros.* +import io.joern.scanners.* import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* object NetworkCommunication extends QueryBundle { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val resolver: ICallResolver = NoResolve // todo: improve by including trust managers created via `object` expressions @@ -37,7 +37,9 @@ object NetworkCommunication extends QueryBundle { def nopTrustManagersAllocs = cpg.method.fullNameExact(Operators.alloc).callIn.typeFullNameExact(nopTrustManagerFullNames*) def sslCtxInitCalls = cpg.method - .fullNameExact("javax.net.ssl.SSLContext.init:void(kotlin.Array,kotlin.Array,java.security.SecureRandom)") + .fullNameExact( + "javax.net.ssl.SSLContext.init:void(javax.net.ssl.KeyManager[],javax.net.ssl.TrustManager[],java.security.SecureRandom)" + ) .callIn sslCtxInitCalls.filter { call => call.argument(2).reachableBy(nopTrustManagersAllocs).nonEmpty diff --git a/querydb/src/main/scala/io/joern/scanners/kotlin/PathTraversals.scala b/querydb/src/main/scala/io/joern/scanners/kotlin/PathTraversals.scala index fe64c33c1203..97a08137614d 100644 --- a/querydb/src/main/scala/io/joern/scanners/kotlin/PathTraversals.scala +++ b/querydb/src/main/scala/io/joern/scanners/kotlin/PathTraversals.scala @@ -1,15 +1,15 @@ package io.joern.scanners.kotlin -import io.joern.scanners._ -import io.joern.console._ +import io.joern.scanners.* +import io.joern.console.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics -import io.joern.dataflowengineoss.language._ -import io.joern.macros.QueryMacros._ -import io.shiftleft.semanticcpg.language._ +import io.joern.dataflowengineoss.semanticsloader.NoSemantics +import io.joern.dataflowengineoss.language.* +import io.joern.macros.QueryMacros.* +import io.shiftleft.semanticcpg.language.* object PathTraversals extends QueryBundle { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) implicit val resolver: ICallResolver = NoResolve @q diff --git a/querydb/src/main/scala/io/joern/scanners/php/SQLInjection.scala b/querydb/src/main/scala/io/joern/scanners/php/SQLInjection.scala index 0829f18301c0..7fbaa232ad3d 100644 --- a/querydb/src/main/scala/io/joern/scanners/php/SQLInjection.scala +++ b/querydb/src/main/scala/io/joern/scanners/php/SQLInjection.scala @@ -1,12 +1,12 @@ package io.joern.scanners.php -import io.joern.console._ -import io.joern.dataflowengineoss.language._ +import io.joern.console.* +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.macros.QueryMacros._ -import io.joern.scanners._ +import io.joern.macros.QueryMacros.* +import io.joern.scanners.* import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* object SQLInjection extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/php/ShellExec.scala b/querydb/src/main/scala/io/joern/scanners/php/ShellExec.scala index 923538116f84..cbf2854c4c56 100644 --- a/querydb/src/main/scala/io/joern/scanners/php/ShellExec.scala +++ b/querydb/src/main/scala/io/joern/scanners/php/ShellExec.scala @@ -1,12 +1,12 @@ package io.joern.scanners.php -import io.joern.console._ -import io.joern.dataflowengineoss.language._ +import io.joern.console.* +import io.joern.dataflowengineoss.language.* import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.macros.QueryMacros._ -import io.joern.scanners._ +import io.joern.macros.QueryMacros.* +import io.joern.scanners.* import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* object ShellExec extends QueryBundle { diff --git a/querydb/src/main/scala/io/joern/scanners/php/TwigTemplateInjection.scala b/querydb/src/main/scala/io/joern/scanners/php/TwigTemplateInjection.scala new file mode 100644 index 000000000000..747f81f2deb8 --- /dev/null +++ b/querydb/src/main/scala/io/joern/scanners/php/TwigTemplateInjection.scala @@ -0,0 +1,122 @@ +package io.joern.scanners.php + +import io.joern.console.* +import io.joern.dataflowengineoss.language.* +import io.joern.dataflowengineoss.queryengine.EngineContext +import io.joern.macros.QueryMacros.* +import io.joern.scanners.* +import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.semanticcpg.language.* + +object TwigTemplateInjection extends QueryBundle { + + implicit val resolver: ICallResolver = NoResolve + + @q + def TwigTemplateInjection()(implicit context: EngineContext): Query = + Query.make( + name = "twig-template-injection", + author = Crew.SJ1iu, + title = "Twig-Template-Injection: A parameter controlled by the user is rendered within a Twig template.", + description = """ + |An attacker controlled parameter is used in an twig template. + | + |This doesn't necessarily indicate a Twig template injection, but if the input is not sanitized and the escape settings are disabled in the application, it could potentially lead to a template injection vulnerability. + |""".stripMargin, + score = 5, + withStrRep({ cpg => + + def source = + cpg.call.name(Operators.assignment).argument.code("(?i).*request.*") + + def sink = + cpg.call.name("createTemplate").methodFullName("(?i).*twig.*").argument + + sink.reachableBy(source).iterator + + }), + tags = List(QueryTags.remoteCodeExecution, QueryTags.default), + multiFileCodeExamples = MultiFileCodeExamples( + positive = List( + List( + CodeSnippet( + """ + | false, // Disable caching for development + | 'debug' => true, // Enable debugging + | 'autoescape' => false // Disabling auto-escaping can lead to template injection vulnerabilities, potentially allowing command execution if the input is not properly sanitized + |]); + | + |// Get the 'name' parameter from the request, some other dummy parameters are provided but not in use. The rule will only detect the vulnerable parameter "name". The rule is granular enough to detect other requests, such as those using Symfony\Component\HttpFoundation\Request. + |$name = $_REQUEST['name'] ?? 'Guest'; + |$name2 = $_REQUEST['name2'] ?? 'Guest'; + |$name3 = $_REQUEST['name3'] ?? 'Guest'; + |$name4 = $_REQUEST['name4'] ?? 'Guest'; + |$name5 = $_REQUEST['name5'] ?? 'Guest'; + | + |// Render a dynamic template using createTemplate. This is the sink. + |$template = $twig->createTemplate("Hello, {$name}! Welcome to Twig dynamic templates."); + | + |// Render + |echo $template->render(['name' => $name]); + |""".stripMargin, + "Positive.kt" + ) + ) + ), + negative = List( + List( + CodeSnippet( + """ + | false, // Disable caching for development + | 'debug' => true, // Enable debugging + | 'autoescape' => false + |]); + | + |// This time a custom function named "createTemplate" is defined which has no template injection issues. It's simply echo the user's input. + |function createTemplate($templateString) { + | echo $templateString; + |} + | + |// Get the 'name' parameter from the request, some other dummy parameters are provided but not in use. + |$name = $_REQUEST['name'] ?? 'Guest'; + |$name2 = $_REQUEST['name2'] ?? 'Guest'; + |$name3 = $_REQUEST['name3'] ?? 'Guest'; + |$name4 = $_REQUEST['name4'] ?? 'Guest'; + |$name5 = $_REQUEST['name5'] ?? 'Guest'; + | + |// All Twig functions below are commented out, but the custom "createTemplate" function is called. No Twig template injection occurred this time, and the rule will not report an issue even if the function names "createTemplate" is invoked. + |createTemplate($name); + | + |// Render a regular template + |// echo $twig->render('template.twig', ['title' => 'Twig Setup', 'name' => $name]); + | + |// Render a dynamic template using createTemplate + |// $template = $twig->createTemplate("Hello, {$name}! Welcome to Twig dynamic templates."); + | + |// echo $template->render(['name' => $name]); + |""".stripMargin, + "Negative.kt" + ) + ) + ) + ) + ) +} diff --git a/querydb/src/test/scala/io/joern/scanners/android/RootDetectionTests.scala b/querydb/src/test/scala/io/joern/scanners/android/RootDetectionTests.scala index c8280bdabd37..e4a0cebea4e0 100644 --- a/querydb/src/test/scala/io/joern/scanners/android/RootDetectionTests.scala +++ b/querydb/src/test/scala/io/joern/scanners/android/RootDetectionTests.scala @@ -1,11 +1,11 @@ package io.joern.scanners.android import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.Semantics +import io.joern.dataflowengineoss.semanticsloader.NoSemantics import io.joern.suites.KotlinQueryTestSuite class RootDetectionTests extends KotlinQueryTestSuite(RootDetection) { - implicit val engineContext: EngineContext = EngineContext(Semantics.empty) + implicit val engineContext: EngineContext = EngineContext(NoSemantics) "the `rootDetectionViaFileChecks` query" when { "should match on all multi-file positive examples" in { diff --git a/querydb/src/test/scala/io/joern/scanners/android/UnprotectedAppPartsTests.scala b/querydb/src/test/scala/io/joern/scanners/android/UnprotectedAppPartsTests.scala index 81d807064d2f..c013457a12fa 100644 --- a/querydb/src/test/scala/io/joern/scanners/android/UnprotectedAppPartsTests.scala +++ b/querydb/src/test/scala/io/joern/scanners/android/UnprotectedAppPartsTests.scala @@ -1,9 +1,9 @@ package io.joern.scanners.android -import io.joern.console.scan._ +import io.joern.console.scan.* import io.shiftleft.codepropertygraph.generated.nodes.CfgNode import io.joern.suites.KotlinQueryTestSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class UnprotectedAppPartsTests extends KotlinQueryTestSuite(UnprotectedAppParts) { diff --git a/querydb/src/test/scala/io/joern/scanners/c/CopyLoopTests.scala b/querydb/src/test/scala/io/joern/scanners/c/CopyLoopTests.scala index 4c20bd6d4791..82520686a738 100644 --- a/querydb/src/test/scala/io/joern/scanners/c/CopyLoopTests.scala +++ b/querydb/src/test/scala/io/joern/scanners/c/CopyLoopTests.scala @@ -2,8 +2,8 @@ package io.joern.scanners.c import io.joern.suites.CQueryTestSuite import io.shiftleft.codepropertygraph.generated.nodes -import io.shiftleft.semanticcpg.language._ -import io.joern.console.scan._ +import io.shiftleft.semanticcpg.language.* +import io.joern.console.scan.* class CopyLoopTests extends CQueryTestSuite(CopyLoops) { diff --git a/querydb/src/test/scala/io/joern/scanners/c/HeapBasedOverflowTests.scala b/querydb/src/test/scala/io/joern/scanners/c/HeapBasedOverflowTests.scala index b583fc4ddf61..e3bedcba6b1e 100644 --- a/querydb/src/test/scala/io/joern/scanners/c/HeapBasedOverflowTests.scala +++ b/querydb/src/test/scala/io/joern/scanners/c/HeapBasedOverflowTests.scala @@ -2,7 +2,8 @@ package io.joern.scanners.c import io.joern.suites.CQueryTestSuite import io.shiftleft.codepropertygraph.generated.nodes -import io.joern.console.scan._ +import io.joern.console.scan.* +import io.shiftleft.semanticcpg.language.* class HeapBasedOverflowTests extends CQueryTestSuite(HeapBasedOverflow) { diff --git a/querydb/src/test/scala/io/joern/scanners/c/IntegerTruncationsTests.scala b/querydb/src/test/scala/io/joern/scanners/c/IntegerTruncationsTests.scala index b4fac1fd3913..9e3b5d177c99 100644 --- a/querydb/src/test/scala/io/joern/scanners/c/IntegerTruncationsTests.scala +++ b/querydb/src/test/scala/io/joern/scanners/c/IntegerTruncationsTests.scala @@ -2,8 +2,8 @@ package io.joern.scanners.c import io.joern.suites.CQueryTestSuite import io.shiftleft.codepropertygraph.generated.nodes -import io.shiftleft.semanticcpg.language._ -import io.joern.console.scan._ +import io.shiftleft.semanticcpg.language.* +import io.joern.console.scan.* class IntegerTruncationsTests extends CQueryTestSuite(IntegerTruncations) { diff --git a/querydb/src/test/scala/io/joern/scanners/c/MetricsTests.scala b/querydb/src/test/scala/io/joern/scanners/c/MetricsTests.scala index 91d0dd61d4eb..a741cea32590 100644 --- a/querydb/src/test/scala/io/joern/scanners/c/MetricsTests.scala +++ b/querydb/src/test/scala/io/joern/scanners/c/MetricsTests.scala @@ -2,7 +2,8 @@ package io.joern.scanners.c import io.joern.suites.CQueryTestSuite import io.shiftleft.codepropertygraph.generated.nodes -import io.joern.console.scan._ +import io.joern.console.scan.* +import io.shiftleft.semanticcpg.language.* class MetricsTests extends CQueryTestSuite(Metrics) { diff --git a/querydb/src/test/scala/io/joern/scanners/c/NullTerminationTests.scala b/querydb/src/test/scala/io/joern/scanners/c/NullTerminationTests.scala index 139c01859aec..327618fd92f8 100644 --- a/querydb/src/test/scala/io/joern/scanners/c/NullTerminationTests.scala +++ b/querydb/src/test/scala/io/joern/scanners/c/NullTerminationTests.scala @@ -2,8 +2,8 @@ package io.joern.scanners.c import io.joern.suites.CQueryTestSuite import io.shiftleft.codepropertygraph.generated.nodes -import io.shiftleft.semanticcpg.language._ -import io.joern.console.scan._ +import io.shiftleft.semanticcpg.language.* +import io.joern.console.scan.* class NullTerminationTests extends CQueryTestSuite(NullTermination) { diff --git a/querydb/src/test/scala/io/joern/scanners/c/QueryWithReachableBy.scala b/querydb/src/test/scala/io/joern/scanners/c/QueryWithReachableBy.scala index 6a6d21b311df..6eff6c970a4f 100644 --- a/querydb/src/test/scala/io/joern/scanners/c/QueryWithReachableBy.scala +++ b/querydb/src/test/scala/io/joern/scanners/c/QueryWithReachableBy.scala @@ -1,10 +1,10 @@ package io.joern.scanners.c -import io.joern.scanners._ -import io.joern.console._ -import io.joern.dataflowengineoss.language._ -import io.shiftleft.semanticcpg.language._ -import io.joern.macros.QueryMacros._ +import io.joern.scanners.* +import io.joern.console.* +import io.joern.dataflowengineoss.language.* +import io.shiftleft.semanticcpg.language.* +import io.joern.macros.QueryMacros.* import io.joern.dataflowengineoss.queryengine.EngineContext /** Just to make sure that we support reachableBy queries, which did not work before diff --git a/querydb/src/test/scala/io/joern/scanners/c/UseAfterFreePostUsage.scala b/querydb/src/test/scala/io/joern/scanners/c/UseAfterFreePostUsage.scala index 568bb515c515..ed2611ddacbe 100644 --- a/querydb/src/test/scala/io/joern/scanners/c/UseAfterFreePostUsage.scala +++ b/querydb/src/test/scala/io/joern/scanners/c/UseAfterFreePostUsage.scala @@ -2,8 +2,8 @@ package io.joern.scanners.c import io.joern.suites.CQueryTestSuite import io.shiftleft.codepropertygraph.generated.nodes -import io.joern.console.scan._ -import io.shiftleft.semanticcpg.language._ +import io.joern.console.scan.* +import io.shiftleft.semanticcpg.language.* class UseAfterFreePostUsage extends CQueryTestSuite(UseAfterFree) { diff --git a/querydb/src/test/scala/io/joern/scanners/c/UseAfterFreeReturnTests.scala b/querydb/src/test/scala/io/joern/scanners/c/UseAfterFreeReturnTests.scala index 3ae07cc654f9..190be8f9bd57 100644 --- a/querydb/src/test/scala/io/joern/scanners/c/UseAfterFreeReturnTests.scala +++ b/querydb/src/test/scala/io/joern/scanners/c/UseAfterFreeReturnTests.scala @@ -2,8 +2,8 @@ package io.joern.scanners.c import io.joern.suites.CQueryTestSuite import io.shiftleft.codepropertygraph.generated.nodes -import io.joern.console.scan._ -import io.shiftleft.semanticcpg.language._ +import io.joern.console.scan.* +import io.shiftleft.semanticcpg.language.* class UseAfterFreeReturnTests extends CQueryTestSuite(UseAfterFree) { diff --git a/querydb/src/test/scala/io/joern/scanners/kotlin/NetworkProtocolsTests.scala b/querydb/src/test/scala/io/joern/scanners/kotlin/NetworkProtocolsTests.scala index 3c0bbfa01176..ae41b9d9d210 100644 --- a/querydb/src/test/scala/io/joern/scanners/kotlin/NetworkProtocolsTests.scala +++ b/querydb/src/test/scala/io/joern/scanners/kotlin/NetworkProtocolsTests.scala @@ -1,9 +1,9 @@ package io.joern.scanners.kotlin -import io.joern.console.scan._ +import io.joern.console.scan.* import io.joern.suites.KotlinQueryTestSuite import io.shiftleft.codepropertygraph.generated.nodes.Call -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class NetworkProtocolsTests extends KotlinQueryTestSuite(NetworkProtocols) { "should find calls relevant to insecure network protocol usage" in { diff --git a/querydb/src/test/scala/io/joern/suites/AllBundlesTestSuite.scala b/querydb/src/test/scala/io/joern/suites/AllBundlesTestSuite.scala index 9fc370a5284f..f35909393c30 100644 --- a/querydb/src/test/scala/io/joern/suites/AllBundlesTestSuite.scala +++ b/querydb/src/test/scala/io/joern/suites/AllBundlesTestSuite.scala @@ -2,7 +2,7 @@ package io.joern.suites import io.joern.console.QueryDatabase import org.scalatest.wordspec.AnyWordSpec -import org.scalatest.matchers.should.Matchers._ +import org.scalatest.matchers.should.Matchers.* class AllBundlesTestSuite extends AnyWordSpec { val argumentProvider = new QDBArgumentProvider(3) diff --git a/querydb/src/test/scala/io/joern/suites/AndroidQueryTestSuite.scala b/querydb/src/test/scala/io/joern/suites/AndroidQueryTestSuite.scala index 4be9b05268b7..f40e1d3f7529 100644 --- a/querydb/src/test/scala/io/joern/suites/AndroidQueryTestSuite.scala +++ b/querydb/src/test/scala/io/joern/suites/AndroidQueryTestSuite.scala @@ -1,12 +1,12 @@ package io.joern.suites -import io.joern.console.scan._ +import io.joern.console.scan.* import io.joern.console.{CodeSnippet, Query, QueryBundle} import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.joern.util.QueryUtil import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.ConfigFile -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class AndroidQueryTestSuite[QB <: QueryBundle](val queryBundle: QB) extends KotlinCode2CpgFixture(withOssDataflow = true, withDefaultJars = true) { diff --git a/querydb/src/test/scala/io/joern/suites/CQueryTestSuite.scala b/querydb/src/test/scala/io/joern/suites/CQueryTestSuite.scala index 93e5d618c08c..97e29744fd44 100644 --- a/querydb/src/test/scala/io/joern/suites/CQueryTestSuite.scala +++ b/querydb/src/test/scala/io/joern/suites/CQueryTestSuite.scala @@ -2,12 +2,12 @@ package io.joern.suites import io.joern.util.QueryUtil import io.shiftleft.codepropertygraph.generated.nodes -import io.joern.console.scan._ +import io.joern.console.scan.* import io.joern.console.QueryBundle import io.joern.console.Query import io.joern.c2cpg.testfixtures.DataFlowCodeToCpgSuite import io.joern.x2cpg.testfixtures.TestCpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class CQueryTestSuite[QB <: QueryBundle](val queryBundle: QB) extends DataFlowCodeToCpgSuite { diff --git a/querydb/src/test/scala/io/joern/suites/GhidraQueryTestSuite.scala b/querydb/src/test/scala/io/joern/suites/GhidraQueryTestSuite.scala index 1a6d1cc5eadc..81ea14d8af67 100644 --- a/querydb/src/test/scala/io/joern/suites/GhidraQueryTestSuite.scala +++ b/querydb/src/test/scala/io/joern/suites/GhidraQueryTestSuite.scala @@ -1,12 +1,12 @@ package io.joern.suites import io.joern.console.QueryBundle -import io.joern.console.scan._ +import io.joern.console.scan.* import io.joern.ghidra2cpg.fixtures.DataFlowBinToCpgSuite import io.joern.util.QueryUtil import io.shiftleft.codepropertygraph.generated.nodes import io.joern.console.Query -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.utils.ProjectRoot class GhidraQueryTestSuite[QB <: QueryBundle](val queryBundle: QB) extends DataFlowBinToCpgSuite { diff --git a/querydb/src/test/scala/io/joern/suites/JavaQueryTestSuite.scala b/querydb/src/test/scala/io/joern/suites/JavaQueryTestSuite.scala index 3407c2837ee8..00711ed73b9b 100644 --- a/querydb/src/test/scala/io/joern/suites/JavaQueryTestSuite.scala +++ b/querydb/src/test/scala/io/joern/suites/JavaQueryTestSuite.scala @@ -1,12 +1,13 @@ package io.joern.suites -import io.joern.console.scan._ +import io.joern.console.scan.* import io.joern.console.{CodeSnippet, Query, QueryBundle} import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.joern.util.QueryUtil import io.joern.x2cpg.testfixtures.TestCpg import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Call, Literal, Method, StoredNode} +import io.shiftleft.semanticcpg.language.* class JavaQueryTestSuite[QB <: QueryBundle](val queryBundle: QB) extends JavaSrcCode2CpgFixture(withOssDataflow = true) { diff --git a/querydb/src/test/scala/io/joern/suites/KotlinQueryTestSuite.scala b/querydb/src/test/scala/io/joern/suites/KotlinQueryTestSuite.scala index 01c102334903..469c1517b919 100644 --- a/querydb/src/test/scala/io/joern/suites/KotlinQueryTestSuite.scala +++ b/querydb/src/test/scala/io/joern/suites/KotlinQueryTestSuite.scala @@ -6,7 +6,8 @@ import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture import io.joern.x2cpg.testfixtures.TestCpg import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Call, Method} -import io.joern.console.scan._ +import io.shiftleft.semanticcpg.language.* +import io.joern.console.scan.* import io.shiftleft.utils.ProjectRoot class KotlinQueryTestSuite[QB <: QueryBundle](val queryBundle: QB) diff --git a/querydb/src/test/scala/io/joern/suites/QDBArgumentProvider.scala b/querydb/src/test/scala/io/joern/suites/QDBArgumentProvider.scala index 0f8946a59d9e..7c19ed88d228 100644 --- a/querydb/src/test/scala/io/joern/suites/QDBArgumentProvider.scala +++ b/querydb/src/test/scala/io/joern/suites/QDBArgumentProvider.scala @@ -2,7 +2,7 @@ package io.joern.suites import io.joern.console.DefaultArgumentProvider import io.joern.dataflowengineoss.queryengine.EngineContext -import io.joern.dataflowengineoss.semanticsloader.{Parser, Semantics} +import io.joern.dataflowengineoss.semanticsloader.{FullNameSemanticsParser, Semantics} import java.nio.file.Paths diff --git a/semanticcpg/build.sbt b/semanticcpg/build.sbt index ea2590fc0a0d..8c2c66a7e0b5 100644 --- a/semanticcpg/build.sbt +++ b/semanticcpg/build.sbt @@ -1,11 +1,12 @@ name := "semanticcpg" libraryDependencies ++= Seq( - "io.shiftleft" %% "codepropertygraph" % Versions.cpg, - "com.michaelpollmeier" %% "scala-repl-pp" % Versions.scalaReplPP, - "org.json4s" %% "json4s-native" % Versions.json4s, - "org.apache.commons" % "commons-text" % Versions.commonsText, - "org.scalatest" %% "scalatest" % Versions.scalatest % Test + "io.shiftleft" %% "codepropertygraph" % Versions.cpg, + "com.michaelpollmeier" %% "scala-repl-pp" % Versions.scalaReplPP, + "org.json4s" %% "json4s-native" % Versions.json4s, + "org.scala-lang.modules" %% "scala-xml" % "2.2.0", + "org.apache.commons" % "commons-text" % Versions.commonsText, + "org.scalatest" %% "scalatest" % Versions.scalatest % Test ) Compile / doc / scalacOptions ++= Seq("-doc-title", "semanticcpg apidocs", "-doc-version", version.value) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala index d10dd2108cf0..7a62e226f138 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala @@ -1,9 +1,9 @@ package io.shiftleft.semanticcpg -import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder} import io.shiftleft.codepropertygraph.generated.Properties import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* object Overlays { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala index f1af0f8d8cd3..8ce5442b22d0 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala @@ -1,6 +1,7 @@ package io.shiftleft.semanticcpg.accesspath -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* trait TrackedBase case class TrackedNamedVariable(name: String) extends TrackedBase @@ -28,6 +29,19 @@ case class TrackedMethod(method: MethodRef) extends TrackedMethodOrTypeRef { } case class TrackedTypeRef(typeRef: TypeRef) extends TrackedMethodOrTypeRef { override def code: String = typeRef.code + + override def equals(obj: Any): Boolean = { + obj match { + case TrackedTypeRef(otherTypeRef) => + typeRef.evalTypeOut.head.equals(otherTypeRef.evalTypeOut.head) + case _ => + false + } + } + + override def hashCode(): Int = { + typeRef.evalTypeOut.head.hashCode() + } } case class TrackedAlias(argIndex: Int) extends TrackedBase { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala index f68afc41afcb..52a649f52ffd 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala @@ -3,7 +3,7 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, MethodParameterOut} import io.shiftleft.semanticcpg.dotgenerator.DotSerializer.{Edge, Graph} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class AstGenerator { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala index d1aecf60bb0e..dfee4db69a23 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala @@ -3,7 +3,7 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Method, StoredNode} import io.shiftleft.semanticcpg.dotgenerator.DotSerializer.{Edge, Graph} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import scala.collection.mutable diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala index 2507bcf9e3d6..9223d7d98218 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala @@ -4,7 +4,7 @@ import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.StoredNode import io.shiftleft.semanticcpg.dotgenerator.DotSerializer.Edge -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* class CdgGenerator extends CfgGenerator { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala index 45999fbe46e8..067ce1302792 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala @@ -1,10 +1,9 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.generated.EdgeTypes -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.dotgenerator.DotSerializer.{Edge, Graph} -import io.shiftleft.semanticcpg.language._ -import overflowdb.Node +import io.shiftleft.semanticcpg.language.* class CfgGenerator { @@ -44,12 +43,12 @@ class CfgGenerator { protected def expand(v: StoredNode): Iterator[Edge] = v._cfgOut.map(node => Edge(v, node, edgeType = edgeType)) - private def isConditionInControlStructure(v: Node): Boolean = v match { + private def isConditionInControlStructure(v: StoredNode): Boolean = v match { case id: Identifier => id.astParent.isControlStructure case _ => false } - private def cfgNodeShouldBeDisplayed(v: Node): Boolean = + private def cfgNodeShouldBeDisplayed(v: StoredNode): Boolean = isConditionInControlStructure(v) || !(v.isInstanceOf[Literal] || v.isInstanceOf[Identifier] || diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala index 6ee27691b66f..5f885935d8f1 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala @@ -1,19 +1,19 @@ package io.shiftleft.semanticcpg.dotgenerator -import io.shiftleft.codepropertygraph.generated.PropertyNames import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.Properties import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.utils.MemberAccess +import org.apache.commons.lang3.StringUtils import org.apache.commons.text.StringEscapeUtils -import java.util.Optional import scala.collection.immutable.HashMap import scala.collection.mutable import scala.language.postfixOps object DotSerializer { - private val charLimit = 50 + private val CharLimit = 50 case class Graph( vertices: List[StoredNode], @@ -39,6 +39,8 @@ object DotSerializer { case Some(r) => namedGraphBegin(r) case None => defaultGraphBegin() } + + sb.append(s"""node [shape="rect"]; \n""") val nodeStrings = graph.vertices.map(nodeToDot) val edgeStrings = graph.edges.map(e => edgeToDot(e, withEdgeTypes)) val subgraphStrings = graph.subgraph.zipWithIndex.map { case ((subgraph, nodes), idx) => @@ -63,45 +65,45 @@ object DotSerializer { sb.append(s"""digraph "$name" { \n""") } - private def limit(str: String): String = if (str.length > charLimit) { - s"${str.take(charLimit - 3)}..." - } else { - str - } + private def limit(str: String): String = StringUtils.abbreviate(str, CharLimit) private def stringRepr(vertex: StoredNode): String = { - val maybeLineNo: Optional[AnyRef] = vertex.propertyOption(PropertyNames.LINE_NUMBER) - StringEscapeUtils.escapeHtml4(vertex match { - case call: Call => (call.name, limit(call.code)).toString - case contrl: ControlStructure => (contrl.label, contrl.controlStructureType, contrl.code).toString - case expr: Expression => (expr.label, limit(expr.code), limit(toCfgNode(expr).code)).toString - case method: Method => (method.label, method.name).toString - case ret: MethodReturn => (ret.label, ret.typeFullName).toString - case param: MethodParameterIn => ("PARAM", param.code).toString - case local: Local => (local.label, s"${local.code}: ${local.typeFullName}").toString - case target: JumpTarget => (target.label, target.name).toString - case modifier: Modifier => (modifier.label, modifier.modifierType).toString() - case annoAssign: AnnotationParameterAssign => (annoAssign.label, annoAssign.code).toString() - case annoParam: AnnotationParameter => (annoParam.label, annoParam.code).toString() - case typ: Type => (typ.label, typ.name).toString() - case typeDecl: TypeDecl => (typeDecl.label, typeDecl.name).toString() - case member: Member => (member.label, member.name).toString() - case _ => "" - }) + (if (maybeLineNo.isPresent) s"${maybeLineNo.get()}" else "") + val lineOpt = vertex.property(Properties.LineNumber).map(_.toString) + val attrList = (vertex match { + case call: Call => List(call.name, limit(call.code)) + case ctrl: ControlStructure => List(ctrl.label, ctrl.controlStructureType, ctrl.code) + case expr: Expression => List(expr.label, limit(expr.code), limit(toCfgNode(expr).code)) + case method: Method => List(method.label, method.name) + case ret: MethodReturn => List(ret.label, ret.typeFullName) + case param: MethodParameterIn => List("PARAM", param.code) + case local: Local => List(local.label, s"${local.code}: ${local.typeFullName}") + case target: JumpTarget => List(target.label, target.name) + case modifier: Modifier => List(modifier.label, modifier.modifierType) + case annoAssign: AnnotationParameterAssign => List(annoAssign.label, annoAssign.code) + case annoParam: AnnotationParameter => List(annoParam.label, annoParam.code) + case typ: Type => List(typ.label, typ.name) + case typeDecl: TypeDecl => List(typeDecl.label, typeDecl.name) + case member: Member => List(member.label, member.name) + case _ => List.empty + }).map(l => StringEscapeUtils.escapeHtml4(StringUtils.normalizeSpace(l))) + + (lineOpt match { + case Some(line) => s"${attrList.head}, $line" :: attrList.tail + case None => attrList + }).distinct.mkString("
") } private def toCfgNode(node: StoredNode): CfgNode = { node match { - case node: Identifier => node.parentExpression.get - case node: MethodRef => node.parentExpression.get - case node: Literal => node.parentExpression.get - case node: MethodParameterIn => node.method - case node: MethodParameterOut => node.method.methodReturn - case node: Call if MemberAccess.isGenericMemberAccessName(node.name) => - node.parentExpression.get - case node: CallRepr => node - case node: MethodReturn => node - case node: Expression => node + case node: Identifier => node.parentExpression.get + case node: MethodRef => node.parentExpression.get + case node: Literal => node.parentExpression.get + case node: Call if MemberAccess.isGenericMemberAccessName(node.name) => node.parentExpression.get + case node: MethodParameterOut => node.method.methodReturn + case node: MethodParameterIn => node.method + case node: CallRepr => node + case node: MethodReturn => node + case node: Expression => node } } @@ -119,7 +121,7 @@ object DotSerializer { s""" "${edge.src.id}" -> "${edge.dst.id}" """ + labelStr } - def nodesToSubGraphs(subgraph: String, children: Seq[StoredNode], idx: Int): String = { + private def nodesToSubGraphs(subgraph: String, children: Seq[StoredNode], idx: Int): String = { val escapedName = StringEscapeUtils.escapeHtml4(subgraph) val childString = children.map { c => s" \"${c.id()}\";" }.mkString("\n") s""" subgraph cluster_$idx { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala index 25891f6e7c48..2ec4454d901f 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala @@ -3,7 +3,7 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{StoredNode, Type, TypeDecl} import io.shiftleft.semanticcpg.dotgenerator.DotSerializer.{Edge, Graph} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import scala.collection.mutable diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala index 86944d6da608..6292879b1b0d 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala @@ -1,8 +1,9 @@ package io.shiftleft.semanticcpg.language -import io.shiftleft.codepropertygraph.generated.{Operators, Properties, PropertyNames} -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.accesspath._ +import io.shiftleft.codepropertygraph.generated.{Operators, Properties} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.accesspath.* +import io.shiftleft.semanticcpg.language.* import org.slf4j.LoggerFactory import scala.jdk.CollectionConverters.IteratorHasAsScala @@ -42,8 +43,9 @@ object AccessPathHandling { .collect { case node: Literal => ConstantAccess(node.code) case node: Identifier => ConstantAccess(node.name) - case other if other.propertyOption(PropertyNames.NAME).isPresent => - logger.warn(s"unexpected/deprecated node encountered: $other with properties: ${other.propertiesMap()}") + case other if other.propertyOption(Properties.Name).isDefined => + val properties = other.propertiesMap + logger.warn(s"unexpected/deprecated node encountered: $other with properties: $properties") ConstantAccess(other.property(Properties.Name)) } .getOrElse(VariableAccess) :: tail diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala index 0a00dd9c5586..008e87131dbb 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala @@ -1,8 +1,8 @@ package io.shiftleft.semanticcpg.language -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import org.slf4j.{Logger, LoggerFactory} -import overflowdb.traversal._ import scala.annotation.tailrec @@ -64,7 +64,7 @@ object LocationCreator { @tailrec private def findVertex(node: StoredNode, instanceCheck: StoredNode => Boolean): Option[StoredNode] = - node._astIn.nextOption() match { + node._astIn.iterator.nextOption() match { case Some(head) if instanceCheck(head) => Some(head) case Some(head) => findVertex(head, instanceCheck) case None => None diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala index 2a4384e9c7e1..dbcce1aa4807 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala @@ -1,7 +1,7 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.generated.nodes.NewNode -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder trait HasStoreMethod { def store()(implicit diffBuilder: DiffGraphBuilder): Unit diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala index 9248f9e72b1d..44b52836c0ac 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala @@ -2,7 +2,7 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.{NewNode, NewTagNodePair, StoredNode} -import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder class NewTagNodePairTraversal(traversal: Iterator[NewTagNodePair]) extends HasStoreMethod { override def store()(implicit diffGraph: DiffGraphBuilder): Unit = { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala index 6c7c4024c26b..f0712ce1f87d 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala @@ -1,29 +1,8 @@ package io.shiftleft.semanticcpg.language -import io.shiftleft.codepropertygraph.generated.nodes.{ - Call, - Identifier, - Literal, - Local, - Method, - MethodParameterIn, - MethodParameterOut, - MethodRef, - MethodReturn, - StoredNode -} +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.NodeExtension -import io.shiftleft.semanticcpg.language.nodemethods.{ - CallMethods, - IdentifierMethods, - LiteralMethods, - LocalMethods, - MethodMethods, - MethodParameterInMethods, - MethodParameterOutMethods, - MethodRefMethods, - MethodReturnMethods -} +import io.shiftleft.semanticcpg.language.nodemethods.* trait NodeExtensionFinder { def apply(n: StoredNode): Option[NodeExtension] diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala index 71971d654828..337714e94255 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala @@ -1,18 +1,16 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} import io.shiftleft.semanticcpg.codedumper.CodeDumper -import overflowdb.Node -import overflowdb.traversal._ -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} /** Steps for all node types * * This is the base class for all steps defined on */ -@help.Traversal(elementType = classOf[StoredNode]) +@Traversal(elementType = classOf[StoredNode]) class NodeSteps[NodeType <: StoredNode](val traversal: Iterator[NodeType]) extends AnyVal { @Doc( @@ -23,15 +21,16 @@ class NodeSteps[NodeType <: StoredNode](val traversal: Iterator[NodeType]) exten |the file node that represents that source file. |""" ) - def file: Iterator[File] = - traversal - .choose(_.label) { - case NodeTypes.NAMESPACE => _.in(EdgeTypes.REF).out(EdgeTypes.SOURCE_FILE) - case NodeTypes.COMMENT => _.in(EdgeTypes.AST).hasLabel(NodeTypes.FILE) - case _ => - _.repeat(_.coalesce(_.out(EdgeTypes.SOURCE_FILE), _.in(EdgeTypes.AST)))(_.until(_.hasLabel(NodeTypes.FILE))) - } - .cast[File] + def file: Iterator[File] = { + traversal.flatMap { + case namespace: Namespace => + namespace.refIn.sourceFileOut + case comment: Comment => + comment.astIn + case node => + Iterator(node).repeat(_.coalesce(_._sourceFileOut, _._astIn))(_.until(_.hasLabel(File.Label))).cast[File] + } + } @Doc( info = "Location, including filename and line number", @@ -84,11 +83,6 @@ class NodeSteps[NodeType <: StoredNode](val traversal: Iterator[NodeType]) exten }.l } - /* follow the incoming edges of the given type as long as possible */ - protected def walkIn(edgeType: String): Iterator[Node] = - traversal - .repeat(_.in(edgeType))(_.until(_.in(edgeType).countTrav.filter(_ == 0))) - @Doc( info = "Tag node with `tagName`", longInfo = """ diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala index daa8f1b82e23..4f45f3f3bea3 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala @@ -1,318 +1,92 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{NodeTypes, Properties} -import overflowdb._ -import overflowdb.traversal.help -import overflowdb.traversal.help.Doc -import overflowdb.traversal.{InitialTraversal, TraversalSource} +import io.shiftleft.codepropertygraph.generated.help.{Doc, TraversalSource} +import io.shiftleft.semanticcpg.language.* -import scala.jdk.CollectionConverters.IteratorHasAsScala +/** Starting point for a new traversal, e.g. + * - `cpg.method`, `cpg.call` etc. - these are generated by the flatgraph codegenerator and automatically inherited + * - `cpg.method.name` + */ +@TraversalSource +class NodeTypeStarters(cpg: Cpg) { -@help.TraversalSource -class NodeTypeStarters(cpg: Cpg) extends TraversalSource(cpg.graph) { - - /** Traverse to all nodes. - */ - @Doc(info = "All nodes of the graph") - override def all: Traversal[StoredNode] = - cpg.graph.nodes.asScala.cast[StoredNode] - - /** Traverse to all annotations - */ - def annotation: Traversal[Annotation] = - InitialTraversal.from[Annotation](cpg.graph, NodeTypes.ANNOTATION) - - /** Traverse to all arguments passed to methods - */ + /** Traverse to all arguments passed to methods */ @Doc(info = "All arguments (actual parameters)") - def argument: Traversal[Expression] = - call.argument + def argument: Iterator[Expression] = + cpg.call.argument - /** Shorthand for `cpg.argument.code(code)` - */ - def argument(code: String): Traversal[Expression] = - argument.code(code) + /** Shorthand for `cpg.argument.code(code)` */ + def argument(code: String): Iterator[Expression] = + cpg.argument.code(code) @Doc(info = "All breaks (`ControlStructure` nodes)") - def break: Traversal[ControlStructure] = - controlStructure.isBreak - - /** Traverse to all call sites - */ - @Doc(info = "All call sites") - def call: Traversal[Call] = - InitialTraversal.from[Call](cpg.graph, NodeTypes.CALL) - - /** Shorthand for `cpg.call.name(name)` - */ - def call(name: String): Traversal[Call] = - call.name(name) - - /** Traverse to all comments in source-based CPGs. - */ - @Doc(info = "All comments in source-based CPGs") - def comment: Traversal[Comment] = - InitialTraversal.from[Comment](cpg.graph, NodeTypes.COMMENT) - - /** Shorthand for `cpg.comment.code(code)` - */ - def comment(code: String): Traversal[Comment] = - comment.has(Properties.Code -> code) - - /** Traverse to all config files - */ - @Doc(info = "All config files") - def configFile: Traversal[ConfigFile] = - InitialTraversal.from[ConfigFile](cpg.graph, NodeTypes.CONFIG_FILE) - - /** Shorthand for `cpg.configFile.name(name)` - */ - def configFile(name: String): Traversal[ConfigFile] = - configFile.name(name) - - /** Traverse to all dependencies - */ - @Doc(info = "All dependencies") - def dependency: Traversal[Dependency] = - InitialTraversal.from[Dependency](cpg.graph, NodeTypes.DEPENDENCY) - - /** Shorthand for `cpg.dependency.name(name)` - */ - def dependency(name: String): Traversal[Dependency] = - dependency.name(name) - - @Doc(info = "All control structures (source-based frontends)") - def controlStructure: Traversal[ControlStructure] = - InitialTraversal.from[ControlStructure](cpg.graph, NodeTypes.CONTROL_STRUCTURE) + def break: Iterator[ControlStructure] = + cpg.controlStructure.isBreak @Doc(info = "All continues (`ControlStructure` nodes)") - def continue: Traversal[ControlStructure] = - controlStructure.isContinue + def continue: Iterator[ControlStructure] = + cpg.controlStructure.isContinue @Doc(info = "All do blocks (`ControlStructure` nodes)") - def doBlock: Traversal[ControlStructure] = - controlStructure.isDo + def doBlock: Iterator[ControlStructure] = + cpg.controlStructure.isDo @Doc(info = "All else blocks (`ControlStructure` nodes)") - def elseBlock: Traversal[ControlStructure] = - controlStructure.isElse + def elseBlock: Iterator[ControlStructure] = + cpg.controlStructure.isElse @Doc(info = "All throws (`ControlStructure` nodes)") - def throws: Traversal[ControlStructure] = - controlStructure.isThrow - - /** Traverse to all source files - */ - @Doc(info = "All source files") - def file: Traversal[File] = - InitialTraversal.from[File](cpg.graph, NodeTypes.FILE) - - /** Shorthand for `cpg.file.name(name)` - */ - def file(name: String): Traversal[File] = - file.name(name) + def throws: Iterator[ControlStructure] = + cpg.controlStructure.isThrow @Doc(info = "All for blocks (`ControlStructure` nodes)") - def forBlock: Traversal[ControlStructure] = - controlStructure.isFor + def forBlock: Iterator[ControlStructure] = + cpg.controlStructure.isFor @Doc(info = "All gotos (`ControlStructure` nodes)") - def goto: Traversal[ControlStructure] = - controlStructure.isGoto - - /** Traverse to all identifiers, e.g., occurrences of local variables or class members in method bodies. - */ - @Doc(info = "All identifier usages") - def identifier: Traversal[Identifier] = - InitialTraversal.from[Identifier](cpg.graph, NodeTypes.IDENTIFIER) - - /** Shorthand for `cpg.identifier.name(name)` - */ - def identifier(name: String): Traversal[Identifier] = - identifier.name(name) + def goto: Iterator[ControlStructure] = + cpg.controlStructure.isGoto @Doc(info = "All if blocks (`ControlStructure` nodes)") - def ifBlock: Traversal[ControlStructure] = - controlStructure.isIf - - /** Traverse to all jump targets - */ - @Doc(info = "All jump targets, i.e., labels") - def jumpTarget: Traversal[JumpTarget] = - InitialTraversal.from[JumpTarget](cpg.graph, NodeTypes.JUMP_TARGET) - - /** Traverse to all local variable declarations - */ - @Doc(info = "All local variables") - def local: Traversal[Local] = - InitialTraversal.from[Local](cpg.graph, NodeTypes.LOCAL) - - /** Shorthand for `cpg.local.name` - */ - def local(name: String): Traversal[Local] = - local.name(name) - - /** Traverse to all literals (constant strings and numbers provided directly in the code). - */ - @Doc(info = "All literals, e.g., numbers or strings") - def literal: Traversal[Literal] = - InitialTraversal.from[Literal](cpg.graph, NodeTypes.LITERAL) - - /** Shorthand for `cpg.literal.code(code)` - */ - def literal(code: String): Traversal[Literal] = - literal.code(code) - - /** Traverse to all methods - */ - @Doc(info = "All methods") - def method: Traversal[Method] = - InitialTraversal.from[Method](cpg.graph, NodeTypes.METHOD) - - /** Shorthand for `cpg.method.name(name)` - */ - @Doc(info = "All methods with a name that matches the given pattern") - def method(namePattern: String): Traversal[Method] = - method.name(namePattern) - - /** Traverse to all formal return parameters - */ - @Doc(info = "All formal return parameters") - def methodReturn: Traversal[MethodReturn] = - InitialTraversal.from[MethodReturn](cpg.graph, NodeTypes.METHOD_RETURN) - - /** Traverse to all class members - */ - @Doc(info = "All members of complex types (e.g., classes/structures)") - def member: Traversal[Member] = - InitialTraversal.from[Member](cpg.graph, NodeTypes.MEMBER) - - /** Shorthand for `cpg.member.name(name)` - */ - def member(name: String): Traversal[Member] = - member.name(name) - - /** Traverse to all meta data entries - */ - @Doc(info = "Meta data blocks for graph") - def metaData: Traversal[MetaData] = - InitialTraversal.from[MetaData](cpg.graph, NodeTypes.META_DATA) - - /** Traverse to all method references - */ - @Doc(info = "All method references") - def methodRef: Traversal[MethodRef] = - InitialTraversal.from[MethodRef](cpg.graph, NodeTypes.METHOD_REF) - - /** Shorthand for `cpg.methodRef.filter(_.referencedMethod.name(name))` - */ - def methodRef(name: String): Traversal[MethodRef] = - methodRef.where(_.referencedMethod.name(name)) - - /** Traverse to all namespaces, e.g., packages in Java. - */ - @Doc(info = "All namespaces") - def namespace: Traversal[Namespace] = - InitialTraversal.from[Namespace](cpg.graph, NodeTypes.NAMESPACE) - - /** Shorthand for `cpg.namespace.name(name)` - */ - def namespace(name: String): Traversal[Namespace] = - namespace.name(name) - - /** Traverse to all namespace blocks, e.g., packages in Java. - */ - def namespaceBlock: Traversal[NamespaceBlock] = - InitialTraversal.from[NamespaceBlock](cpg.graph, NodeTypes.NAMESPACE_BLOCK) - - /** Shorthand for `cpg.namespaceBlock.name(name)` - */ - def namespaceBlock(name: String): Traversal[NamespaceBlock] = - namespaceBlock.name(name) - - /** Traverse to all input parameters - */ + def ifBlock: Iterator[ControlStructure] = + cpg.controlStructure.isIf + + /** Shorthand for `cpg.methodRef.where(_.referencedMethod.name(name))` + * + * Note re API design: this step was supposed to be called `methodRef(name: String)`, but due to limitations in + * Scala's implicit resolution (and the setup of our implicit steps) we have to disambiguate it from `.methodRef` by + * name. + * + * More precisely: Scala's implicit resolution reports 'ambiguous implicits' if two methods with the same name but + * different parameters are defined in two different (implicitly reachable) classes. The `.methodRef` step is defined + * in `generated.CpgNodeStarter`. This step (filter by name) doesn't get generated by the codegen because it's more + * complex than the other 'filter by primary key' starter steps. + */ + def methodRefWithName(name: String): Iterator[MethodRef] = + cpg.methodRef.where(_.referencedMethod.name(name)) + + /** Traverse to all input parameters */ @Doc(info = "All parameters") - def parameter: Traversal[MethodParameterIn] = - InitialTraversal.from[MethodParameterIn](cpg.graph, NodeTypes.METHOD_PARAMETER_IN) + def parameter: Iterator[MethodParameterIn] = + cpg.methodParameterIn - /** Shorthand for `cpg.parameter.name(name)` - */ - def parameter(name: String): Traversal[MethodParameterIn] = + /** Shorthand for `cpg.parameter.name(name)` */ + def parameter(name: String): Iterator[MethodParameterIn] = parameter.name(name) - /** Traverse to all return expressions - */ - @Doc(info = "All actual return parameters") - def ret: Traversal[Return] = - InitialTraversal.from[Return](cpg.graph, NodeTypes.RETURN) - - /** Shorthand for `returns.code(code)` - */ - def ret(code: String): Traversal[Return] = - ret.code(code) - - @Doc(info = "All imports") - def imports: Traversal[Import] = - InitialTraversal.from[Import](cpg.graph, NodeTypes.IMPORT) - @Doc(info = "All switch blocks (`ControlStructure` nodes)") - def switchBlock: Traversal[ControlStructure] = - controlStructure.isSwitch + def switchBlock: Iterator[ControlStructure] = + cpg.controlStructure.isSwitch @Doc(info = "All try blocks (`ControlStructure` nodes)") - def tryBlock: Traversal[ControlStructure] = - controlStructure.isTry - - /** Traverse to all types, e.g., Set - */ - @Doc(info = "All used types") - def typ: Traversal[Type] = - InitialTraversal.from[Type](cpg.graph, NodeTypes.TYPE) - - /** Shorthand for `cpg.typ.name(name)` - */ - @Doc(info = "All used types with given name") - def typ(name: String): Traversal[Type] = - typ.name(name) - - /** Traverse to all declarations, e.g., Set - */ - @Doc(info = "All declarations of types") - def typeDecl: Traversal[TypeDecl] = - InitialTraversal.from[TypeDecl](cpg.graph, NodeTypes.TYPE_DECL) - - /** Shorthand for cpg.typeDecl.name(name) - */ - def typeDecl(name: String): Traversal[TypeDecl] = - typeDecl.name(name) - - /** Traverse to all tags - */ - @Doc(info = "All tags") - def tag: Traversal[Tag] = - InitialTraversal.from[Tag](cpg.graph, NodeTypes.TAG) - - @Doc(info = "All tags with given name") - def tag(name: String): Traversal[Tag] = - tag.name(name) - - /** Traverse to all template DOM nodes - */ - @Doc(info = "All template DOM nodes") - def templateDom: Traversal[TemplateDom] = - InitialTraversal.from[TemplateDom](cpg.graph, NodeTypes.TEMPLATE_DOM) - - /** Traverse to all type references - */ - @Doc(info = "All type references") - def typeRef: Traversal[TypeRef] = - InitialTraversal.from[TypeRef](cpg.graph, NodeTypes.TYPE_REF) + def tryBlock: Iterator[ControlStructure] = + cpg.controlStructure.isTry @Doc(info = "All while blocks (`ControlStructure` nodes)") - def whileBlock: Traversal[ControlStructure] = - controlStructure.isWhile + def whileBlock: Iterator[ControlStructure] = + cpg.controlStructure.isWhile } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/SarifExtension.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/SarifExtension.scala new file mode 100644 index 000000000000..4a511d2ff21e --- /dev/null +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/SarifExtension.scala @@ -0,0 +1,75 @@ +package io.shiftleft.semanticcpg.language + +import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.codepropertygraph.generated.help.Doc +import io.shiftleft.codepropertygraph.generated.nodes.Finding +import io.shiftleft.semanticcpg.sarif.SarifConfig.SarifVersion +import io.shiftleft.semanticcpg.sarif.SarifSchema.{Sarif, Sarif2_1_0} +import io.shiftleft.semanticcpg.sarif.{SarifConfig, SarifSchema, v2_1_0} +import org.json4s.Formats +import org.json4s.native.Serialization.{write, writePretty} + +import java.net.URI + +/** Converts findings written to the CPG to the SARIF format. + * + * @param traversal + * the findings + * @see + * https://docs.oasis-open.org/sarif/sarif/v2.1.0/sarif-v2.1.0.html + */ +class SarifExtension(val traversal: Iterator[Finding]) extends AnyVal { + + @Doc(info = "execute this traversal and convert findings to SARIF format") + def toSarif(implicit config: SarifConfig = SarifConfig()): Sarif = { + + def generateSarif( + results: List[SarifSchema.Result], + reportingDescriptors: List[SarifSchema.ReportingDescriptor], + baseUri: Option[URI] + ): Sarif = { + config.sarifVersion match { + case SarifVersion.V2_1_0 => + val tool = v2_1_0.Schema.ToolComponent( + name = config.toolName, + fullName = config.toolFullName, + organization = config.organization, + semanticVersion = config.semanticVersion, + informationUri = config.toolInformationUri, + rules = reportingDescriptors + ) + val projectBaseUri = Map( + "PROJECT_ROOT" -> v2_1_0.Schema + .ArtifactLocation(uriBaseId = baseUri.map(_.toString).orElse(Option(""))) + ) + val runs = v2_1_0.Schema.Run( + tool = v2_1_0.Schema.Tool(driver = tool), + originalUriBaseIds = projectBaseUri, + results = results + ) :: Nil + Sarif2_1_0(runs = runs) + } + } + + traversal.l match { + case Nil => generateSarif(results = Nil, reportingDescriptors = Nil, baseUri = None) + case findings @ head :: _ => + val baseUri = Cpg(head.graph).metaData.root.headOption.map(java.io.File(_).toURI) + val results = findings.map(config.resultConverter.convertFindingToResult) + val reportingDescriptors = + findings.flatMap(config.resultConverter.convertFindingToReportingDescriptor).distinctBy(_.id) + generateSarif(results, reportingDescriptors, baseUri) + } + + } + + @Doc(info = "execute this traversal and convert findings to SARIF format as JSON") + def toSarifJson(pretty: Boolean = false)(implicit config: SarifConfig = SarifConfig()): String = { + implicit val formats: Formats = org.json4s.DefaultFormats ++ config.customSerializers + + val results = toSarif + if (pretty) writePretty(results) + else write(results) + } + +} diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala index 0194d4dc571b..e6efabae3f28 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala @@ -1,9 +1,8 @@ package io.shiftleft.semanticcpg.language -import io.shiftleft.codepropertygraph.generated.nodes.NewNode -import overflowdb.Node +import io.shiftleft.codepropertygraph.generated.nodes.{AbstractNode, NewNode, StoredNode} -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* /** Typeclass for (pretty) printing an object */ @@ -18,20 +17,20 @@ object Show { override def apply(a: Any): String = a match { case node: NewNode => val label = node.label - val properties = propsToString(node.properties.toList) + val properties = propsToString(node.properties) s"($label): $properties" - case node: Node => + case node: StoredNode => val label = node.label val id = node.id().toString - val properties = propsToString(node.propertiesMap.asScala.toList) + val properties = propsToString(node.properties) s"($label,$id): $properties" case other => other.toString } - private def propsToString(keyValues: List[(String, Any)]): String = { - keyValues + private def propsToString(properties: Map[String, Any]): String = { + properties .filter(_._2.toString.nonEmpty) .sortBy(_._1) .map { case (key, value) => s"$key: $value" } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala index 4aa0fce401d3..e24170db71a0 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala @@ -1,9 +1,9 @@ package io.shiftleft.semanticcpg.language -import io.shiftleft.codepropertygraph.generated.nodes.AbstractNode +import io.shiftleft.codepropertygraph.generated.nodes.{AbstractNode, StoredNode} import org.json4s.native.Serialization.{write, writePretty} import org.json4s.{CustomSerializer, Extraction, Formats} -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import replpp.Colors import replpp.Operators.* @@ -14,6 +14,7 @@ import scala.jdk.CollectionConverters.* /** Base class for our DSL These are the base steps available in all steps of the query language. There are no * constraints on the element types, unlike e.g. [[NodeSteps]] */ +@Traversal(elementType = classOf[AnyRef]) class Steps[A](val traversal: Iterator[A]) extends AnyVal { /** Execute the traversal and convert it to a mutable buffer @@ -81,14 +82,19 @@ class Steps[A](val traversal: Iterator[A]) extends AnyVal { object Steps { private lazy val nodeSerializer = new CustomSerializer[AbstractNode](implicit format => ( - { case _ => ??? }, - { case node: (AbstractNode & Product) => - val elementMap = (0 until node.productArity).map { i => + { case _ => ??? }, // deserializer not required for now + { case node: AbstractNode => + val elementMap = Map.newBuilder[String, Any] + (0 until node.productArity).foreach { i => val label = node.productElementName(i) val element = node.productElement(i) - label -> element - }.toMap + ("_label" -> node.label) - Extraction.decompose(elementMap) + elementMap.addOne(label -> element) + } + elementMap.addOne("_label" -> node.label) + if (node.isInstanceOf[StoredNode]) { + elementMap.addOne("_id" -> node.asInstanceOf[StoredNode].id()) + } + Extraction.decompose(elementMap.result()) } ) ) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala index 936cf1439df9..c2a8d6e5efc9 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala @@ -1,9 +1,10 @@ package io.shiftleft.semanticcpg.language.android import io.joern.semanticcpg.utils.SecureXmlParsing -import io.shiftleft.codepropertygraph.generated.nodes +import io.shiftleft.codepropertygraph.generated.nodes.ConfigFile +import io.shiftleft.semanticcpg.language.* -class ConfigFileTraversal(val traversal: Iterator[nodes.ConfigFile]) extends AnyVal { +class ConfigFileTraversal(val traversal: Iterator[ConfigFile]) extends AnyVal { def usesCleartextTraffic = traversal .filter(_.name.endsWith(Constants.androidManifestXml)) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala index 6b627ae47e06..29f08c4af74e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala @@ -1,9 +1,11 @@ package io.shiftleft.semanticcpg.language.callgraphextension -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Method} +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc +@Traversal(elementType = classOf[Method]) class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal { /** Intended for internal use! Traverse to direct and transitive callers of the method. diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala index aa10a624c79f..80f3622f27ae 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala @@ -2,7 +2,7 @@ package io.shiftleft.semanticcpg.language.dotextension import io.shiftleft.codepropertygraph.generated.nodes.AstNode import io.shiftleft.semanticcpg.dotgenerator.DotAstGenerator -import overflowdb.traversal.* +import io.shiftleft.semanticcpg.language.* class AstNodeDot[NodeType <: AstNode](val traversal: Iterator[NodeType]) extends AnyVal { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala index 107b19037495..84d1bdf4647d 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala @@ -2,7 +2,7 @@ package io.shiftleft.semanticcpg.language.dotextension import io.shiftleft.codepropertygraph.generated.nodes.Method import io.shiftleft.semanticcpg.dotgenerator.{DotCdgGenerator, DotCfgGenerator} -import overflowdb.traversal.* +import io.shiftleft.semanticcpg.language.* class CfgNodeDot(val traversal: Iterator[Method]) extends AnyVal { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/importresolver/ResolvedImportAsTagTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/importresolver/ResolvedImportAsTagTraversal.scala index 4400d389dcd9..1e6b4d6dead8 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/importresolver/ResolvedImportAsTagTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/importresolver/ResolvedImportAsTagTraversal.scala @@ -1,18 +1,17 @@ package io.shiftleft.semanticcpg.language.importresolver +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Declaration, Member, Tag} import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc class ResolvedImportAsTagExt(node: Tag) extends AnyVal { - @Doc(info = "Parses this tag as an EvaluatedImport class") def _toEvaluatedImport: Option[EvaluatedImport] = EvaluatedImport.tagToEvaluatedImport(node) - @Doc(info = "If this tag represents a resolved import, will attempt to find the CPG entities this refers to") def resolvedEntity: Iterator[AstNode] = { - val cpg = Cpg(node.graph()) + val cpg = Cpg(node.graph) node._toEvaluatedImport.iterator .collectAll[ResolvedImport] .flatMap { @@ -25,9 +24,9 @@ class ResolvedImportAsTagExt(node: Tag) extends AnyVal { } .iterator } - } +@Traversal(elementType = classOf[Tag]) class ResolvedImportAsTagTraversal(steps: Iterator[Tag]) extends AnyVal { @Doc(info = "Parses these tags as EvaluatedImport classes") diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/ModuleVariableAsNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/ModuleVariableAsNodeTraversal.scala index c115d84a3883..741f650ee6ba 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/ModuleVariableAsNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/ModuleVariableAsNodeTraversal.scala @@ -1,12 +1,14 @@ package io.shiftleft.semanticcpg.language.modulevariable +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{Cpg, Operators} import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.modulevariable.OpNodes import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc +@Traversal(elementType = classOf[Local]) class ModuleVariableAsLocalTraversal(traversal: Iterator[Local]) extends AnyVal { @Doc(info = "Locals representing module variables") @@ -16,6 +18,7 @@ class ModuleVariableAsLocalTraversal(traversal: Iterator[Local]) extends AnyVal } +@Traversal(elementType = classOf[Identifier]) class ModuleVariableAsIdentifierTraversal(traversal: Iterator[Identifier]) extends AnyVal { @Doc(info = "Identifiers representing module variables") @@ -25,12 +28,13 @@ class ModuleVariableAsIdentifierTraversal(traversal: Iterator[Identifier]) exten } +@Traversal(elementType = classOf[FieldIdentifier]) class ModuleVariableAsFieldIdentifierTraversal(traversal: Iterator[FieldIdentifier]) extends AnyVal { @Doc(info = "Field identifiers representing module variables") def moduleVariables: Iterator[OpNodes.ModuleVariable] = { traversal.flatMap { fieldIdentifier => - Cpg(fieldIdentifier.graph()).method + Cpg(fieldIdentifier.graph).method .fullNameExact(fieldIdentifier.inFieldAccess.argument(1).isIdentifier.typeFullName.toSeq*) .isModule .local @@ -40,13 +44,14 @@ class ModuleVariableAsFieldIdentifierTraversal(traversal: Iterator[FieldIdentifi } +@Traversal(elementType = classOf[Member]) class ModuleVariableAsMemberTraversal(traversal: Iterator[Member]) extends AnyVal { @Doc(info = "Members representing module variables") def moduleVariables: Iterator[OpNodes.ModuleVariable] = { val members = traversal.toList lazy val memberNames = members.name.toSeq - members.headOption.map(m => Cpg(m.graph())) match + members.headOption.map(m => Cpg(m.graph)) match case Some(cpg) => cpg.method .fullNameExact(members.typeDecl.fullName.toSeq*) @@ -57,6 +62,7 @@ class ModuleVariableAsMemberTraversal(traversal: Iterator[Member]) extends AnyVa } +@Traversal(elementType = classOf[Expression]) class ModuleVariableAsExpressionTraversal(traversal: Iterator[Expression]) extends AnyVal { @Doc(info = "Expression nodes representing module variables") diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/ModuleVariableTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/ModuleVariableTraversal.scala index 1c58587093c2..23e12cd0fe32 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/ModuleVariableTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/ModuleVariableTraversal.scala @@ -1,11 +1,13 @@ package io.shiftleft.semanticcpg.language.modulevariable +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc +@Traversal(elementType = classOf[Local]) class ModuleVariableTraversal(traversal: Iterator[OpNodes.ModuleVariable]) extends AnyVal { @Doc(info = "All assignments where the module variables in this traversal are the target across the program") @@ -32,7 +34,7 @@ class ModuleVariableTraversal(traversal: Iterator[OpNodes.ModuleVariable]) exten ) def references: Iterator[Identifier | FieldIdentifier] = { val variables = traversal.toList - variables.headOption.map(node => Cpg(node.graph())) match + variables.headOption.map(node => Cpg(node.graph)) match case Some(cpg) => val modules = cpg.method.isModule.l val variableNames = variables.name.toSet @@ -78,7 +80,7 @@ class ModuleVariableTraversal(traversal: Iterator[OpNodes.ModuleVariable]) exten val variables = traversal.toList lazy val moduleNames = variables.method.isModule.fullName.dedup.toSeq lazy val variableNames = variables.name.toSeq - variables.headOption.map(node => Cpg(node.graph())) match + variables.headOption.map(node => Cpg(node.graph)) match case Some(cpg) => cpg.typeDecl.fullNameExact(moduleNames*).member.nameExact(variableNames*) case None => Iterator.empty } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/NodeTypeStarters.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/NodeTypeStarters.scala index b3462fd1807e..da34b086d2af 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/NodeTypeStarters.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/NodeTypeStarters.scala @@ -1,7 +1,7 @@ package io.shiftleft.semanticcpg.language.modulevariable +import io.shiftleft.codepropertygraph.generated.help.{Doc, TraversalSource} import io.shiftleft.codepropertygraph.generated.Cpg -import overflowdb.traversal.help.{Doc, TraversalSource} import io.shiftleft.semanticcpg.language.* @TraversalSource diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/OpNodes.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/OpNodes.scala index dd442178f8af..2c43a8bf567c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/OpNodes.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/OpNodes.scala @@ -1,6 +1,6 @@ package io.shiftleft.semanticcpg.language.modulevariable -import io.shiftleft.codepropertygraph.generated.nodes.{Block, Local, Member, StaticType} +import io.shiftleft.codepropertygraph.generated.nodes.{Local, StaticType} trait ModuleVariableT object OpNodes { @@ -8,7 +8,6 @@ object OpNodes { /** Represents a module-level global variable. This kind of node behaves like both a local variable and a field access * and is common in languages such as Python/JavaScript. */ - type ModuleVariable = Local & StaticType[ModuleVariableT] } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/nodemethods/ModuleVariableAsNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/nodemethods/ModuleVariableAsNodeMethods.scala index fb80d7092fa2..f70927b3dea2 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/nodemethods/ModuleVariableAsNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/nodemethods/ModuleVariableAsNodeMethods.scala @@ -2,11 +2,11 @@ package io.shiftleft.semanticcpg.language.modulevariable.nodemethods import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Local, Member} import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc class ModuleVariableAsLocalMethods(node: Local) extends AnyVal { - @Doc(info = "If this local is declared on the module-defining method level") + /** If this local is declared on the module-defining method level */ def isModuleVariable: Boolean = node.method.isModule.nonEmpty } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/nodemethods/ModuleVariableMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/nodemethods/ModuleVariableMethods.scala index ee2a84c2e86d..5a13de925796 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/nodemethods/ModuleVariableMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/modulevariable/nodemethods/ModuleVariableMethods.scala @@ -6,19 +6,19 @@ import io.shiftleft.semanticcpg.language.modulevariable.OpNodes import io.shiftleft.semanticcpg.language.operatorextension.OpNodes as OpExtNodes import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.importresolver.{ResolvedMember, ResolvedTypeDecl} -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc class ModuleVariableMethods(node: OpNodes.ModuleVariable) extends AnyVal { - @Doc(info = "References of this module variable across the codebase, as either identifiers or field identifiers") + /** References of this module variable across the codebase, as either identifiers or field identifiers */ def references: Iterator[Identifier | FieldIdentifier] = node.start.references - @Doc(info = "The module members being referenced in the respective module type declaration") + /** The module members being referenced in the respective module type declaration */ def referencingMembers: Iterator[Member] = { - Cpg(node.graph()).typeDecl.fullNameExact(node.method.fullName.toSeq*).member.nameExact(node.name) + Cpg(node.graph).typeDecl.fullNameExact(node.method.fullName.toSeq*).member.nameExact(node.name) } - @Doc(info = "Returns the assignments where the module variable is the target (LHS)") + /** Returns the assignments where the module variable is the target (LHS) */ def definitions: Iterator[OpExtNodes.Assignment] = node.start.definitions } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala index 14ad5766be64..c23ff1a463d5 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala @@ -56,7 +56,7 @@ class AstNodeMethods(val node: AstNode) extends AnyVal with NodeExtension { val additionalDepth = if (p(node)) { 1 } else { 0 } - val childDepths = node.astChildren.map(_.depth(p)).l + val childDepths = astChildren.map(_.depth(p)).l additionalDepth + (if (childDepths.isEmpty) { 0 } else { @@ -70,7 +70,7 @@ class AstNodeMethods(val node: AstNode) extends AnyVal with NodeExtension { /** Direct children of node in the AST. Siblings are ordered by their `order` fields */ def astChildren: Iterator[AstNode] = - node._astOut.cast[AstNode].sortBy(_.order).iterator + node._astOut.cast[AstNode].toSeq.sortBy(_.order).iterator /** Siblings of this node in the AST, ordered by their `order` fields */ @@ -108,8 +108,7 @@ class AstNodeMethods(val node: AstNode) extends AnyVal with NodeExtension { case member: Member => member case node: MethodParameterIn => node.method - case node: MethodParameterOut => - node.method.methodReturn + case node: MethodParameterOut => node.method.methodReturn case node: Call if MemberAccess.isGenericMemberAccessName(node.name) => parentExpansion(node) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala index b124dcbd0c2e..7a1fd1225eab 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala @@ -6,20 +6,30 @@ import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.* class CallMethods(val node: Call) extends AnyVal with NodeExtension with HasLocation { + + def isStatic: Boolean = + node.dispatchType == DispatchTypes.STATIC_DISPATCH + + def isDynamic: Boolean = + node.dispatchType == DispatchTypes.DYNAMIC_DISPATCH + + def isInline: Boolean = + node.dispatchType == DispatchTypes.INLINED + def receiver: Iterator[Expression] = - node.receiverOut + node.receiverOut.collectAll[Expression] def arguments(index: Int): Iterator[Expression] = - node._argumentOut - .collect { - case expr: Expression if expr.argumentIndex == index => expr - } + node._argumentOut.collect { + case expr: Expression if expr.argumentIndex == index => expr + } + // TODO define as named step in the schema def argument: Iterator[Expression] = node._argumentOut.collectAll[Expression] def argument(index: Int): Expression = - arguments(index).head + arguments(index).next def argumentOption(index: Int): Option[Expression] = node._argumentOut.collectFirst { @@ -32,7 +42,6 @@ class CallMethods(val node: Call) extends AnyVal with NodeExtension with HasLoca node.astChildren.isBlock.maxByOption(_.order).iterator.expressionDown } - override def location: NewLocation = { + override def location: NewLocation = LocationCreator(node, node.code, node.label, node.lineNumber, node.method) - } } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala index 08ee7d72a7e6..fa14297ad531 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala @@ -11,9 +11,8 @@ class CfgNodeMethods(val node: CfgNode) extends AnyVal with NodeExtension { /** Successors in the CFG */ - def cfgNext: Iterator[CfgNode] = { + def cfgNext: Iterator[CfgNode] = Iterator.single(node).cfgNext - } /** Maps each node in the traversal to a traversal returning its n successors. */ @@ -31,9 +30,8 @@ class CfgNodeMethods(val node: CfgNode) extends AnyVal with NodeExtension { /** Predecessors in the CFG */ - def cfgPrev: Iterator[CfgNode] = { + def cfgPrev: Iterator[CfgNode] = Iterator.single(node).cfgPrev - } /** Recursively determine all nodes on which this CFG node is control-dependent. */ diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala index e04ff421b06a..134f42ae3397 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala @@ -2,7 +2,7 @@ package io.shiftleft.semanticcpg.language.nodemethods import io.shiftleft.codepropertygraph.generated.nodes.{Declaration, Identifier, NewLocation} import io.shiftleft.semanticcpg.NodeExtension -import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator, *} +import io.shiftleft.semanticcpg.language.* class IdentifierMethods(val identifier: Identifier) extends AnyVal with NodeExtension with HasLocation { override def location: NewLocation = { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala index 3bbc9892b5bc..8962c206ff4a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala @@ -2,7 +2,7 @@ package io.shiftleft.semanticcpg.language.nodemethods import io.shiftleft.codepropertygraph.generated.nodes.{Literal, NewLocation} import io.shiftleft.semanticcpg.NodeExtension -import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator, _} +import io.shiftleft.semanticcpg.language.* class LiteralMethods(val literal: Literal) extends AnyVal with NodeExtension with HasLocation { override def location: NewLocation = { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala index 9bfa5eac78bc..94abc5458ecd 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala @@ -6,7 +6,7 @@ import io.shiftleft.semanticcpg.language.* class LocalMethods(val local: Local) extends AnyVal with NodeExtension with HasLocation { override def location: NewLocation = { - LocationCreator(local, local.name, local.label, local.lineNumber, local.method.head) + LocationCreator(local, local.name, local.label, local.lineNumber, method.head) } /** The method hosting this local variable diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala index 52a2034ae28f..440d6bf4e082 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala @@ -14,6 +14,13 @@ class MethodMethods(val method: Method) extends AnyVal with NodeExtension with H def local: Iterator[Local] = method._blockViaContainsOut.local + def topLevelExpressions: Iterator[Expression] = + method._astOut + .collectAll[Block] + ._astOut + .not(_.collectAll[Local]) + .cast[Expression] + /** All control structures of this method */ def controlStructure: Iterator[ControlStructure] = @@ -37,14 +44,14 @@ class MethodMethods(val method: Method) extends AnyVal with NodeExtension with H /** List of CFG nodes in reverse post order */ def reversePostOrder: Iterator[CfgNode] = { - def expand(x: CfgNode) = { x.cfgNext.iterator } + def expand(x: CfgNode) = x.cfgNext.iterator NodeOrdering.reverseNodeList(NodeOrdering.postOrderNumbering(method, expand).toList).iterator } /** List of CFG nodes in post order */ def postOrder: Iterator[CfgNode] = { - def expand(x: CfgNode) = { x.cfgNext.iterator } + def expand(x: CfgNode) = x.cfgNext.iterator NodeOrdering.nodeList(NodeOrdering.postOrderNumbering(method, expand).toList).iterator } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala index 1f2a0c5420cc..1b39a89145c3 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala @@ -2,7 +2,7 @@ package io.shiftleft.semanticcpg.language.nodemethods import io.shiftleft.codepropertygraph.generated.nodes.{MethodParameterIn, NewLocation} import io.shiftleft.semanticcpg.NodeExtension -import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator} +import io.shiftleft.semanticcpg.language.* class MethodParameterInMethods(val paramIn: MethodParameterIn) extends AnyVal with NodeExtension with HasLocation { override def location: NewLocation = { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala index 851c5949abbb..29471342f6ca 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala @@ -2,7 +2,7 @@ package io.shiftleft.semanticcpg.language.nodemethods import io.shiftleft.codepropertygraph.generated.nodes.{MethodParameterOut, NewLocation} import io.shiftleft.semanticcpg.NodeExtension -import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator} +import io.shiftleft.semanticcpg.language.* class MethodParameterOutMethods(val paramOut: MethodParameterOut) extends AnyVal with NodeExtension with HasLocation { override def location: NewLocation = { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala index fc971fff51c9..03ea62a6dd3a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala @@ -2,7 +2,7 @@ package io.shiftleft.semanticcpg.language.nodemethods import io.shiftleft.codepropertygraph.generated.nodes.{MethodRef, NewLocation} import io.shiftleft.semanticcpg.NodeExtension -import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator} +import io.shiftleft.semanticcpg.language.* class MethodRefMethods(val methodRef: MethodRef) extends AnyVal with NodeExtension with HasLocation { override def location: NewLocation = { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala index 4ea458735b3b..8d7494c63773 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala @@ -19,5 +19,6 @@ class MethodReturnMethods(val node: MethodReturn) extends AnyVal with NodeExtens callsites.collectAll[Call] } + // TODO define in schema as named step def typ: Iterator[Type] = node.evalTypeOut } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala index 600fe2c1299a..25325e1faccb 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala @@ -1,11 +1,10 @@ package io.shiftleft.semanticcpg.language.nodemethods -import io.shiftleft.codepropertygraph.generated.nodes.{NewLocation, StoredNode} +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.NodeExtension -import io.shiftleft.semanticcpg.language._ -import overflowdb.NodeOrDetachedNode +import io.shiftleft.semanticcpg.language.* -class NodeMethods(val node: NodeOrDetachedNode) extends AnyVal with NodeExtension { +class NodeMethods(val node: AbstractNode) extends AnyVal with NodeExtension { def location(implicit finder: NodeExtensionFinder): NewLocation = node match { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala index cc3b25194d16..4cc8d1d7d840 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala @@ -1,9 +1,11 @@ package io.shiftleft.semanticcpg.language.operatorextension -import io.shiftleft.codepropertygraph.generated.nodes.{Expression, Identifier} -import io.shiftleft.semanticcpg.language._ -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Expression, Identifier} +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.codepropertygraph.generated.help.Doc +@Traversal(elementType = classOf[Call]) class ArrayAccessTraversal(val traversal: Iterator[OpNodes.ArrayAccess]) extends AnyVal { @Doc(info = "The expression representing the array") diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala index f7751bedab93..31cf82c39a6b 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala @@ -1,11 +1,12 @@ package io.shiftleft.semanticcpg.language.operatorextension +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.nodes +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Expression} import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc -@help.Traversal(elementType = classOf[nodes.Call]) +@Traversal(elementType = classOf[Call]) class AssignmentTraversal(val traversal: Iterator[OpNodes.Assignment]) extends AnyVal { @Doc(info = "Left-hand sides of assignments") diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala index 6bb06f0ce627..2a0a535afce9 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala @@ -1,9 +1,11 @@ package io.shiftleft.semanticcpg.language.operatorextension -import io.shiftleft.codepropertygraph.generated.nodes.{FieldIdentifier, Member, TypeDecl} +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Member, TypeDecl} import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc +@Traversal(elementType = classOf[Call]) class FieldAccessTraversal(val traversal: Iterator[OpNodes.FieldAccess]) extends AnyVal { @Doc(info = "Attempts to resolve the type declaration for this field access") diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala index b3ef1b0ce155..5f63f72efef3 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala @@ -2,22 +2,27 @@ package io.shiftleft.semanticcpg.language.operatorextension import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Expression} -import io.shiftleft.semanticcpg.language.operatorextension.nodemethods._ +import io.shiftleft.semanticcpg.language.operatorextension.nodemethods.* trait Implicits { + implicit def toNodeTypeStartersOperatorExtension(cpg: Cpg): NodeTypeStarters = new NodeTypeStarters(cpg) implicit def toArrayAccessExt(arrayAccess: OpNodes.ArrayAccess): ArrayAccessMethods = new ArrayAccessMethods(arrayAccess) + implicit def toArrayAccessTrav(steps: Iterator[OpNodes.ArrayAccess]): ArrayAccessTraversal = new ArrayAccessTraversal(steps) implicit def toFieldAccessExt(fieldAccess: OpNodes.FieldAccess): FieldAccessMethods = new FieldAccessMethods(fieldAccess) + implicit def toFieldAccessTrav(steps: Iterator[OpNodes.FieldAccess]): FieldAccessTraversal = new FieldAccessTraversal(steps) - implicit def toAssignmentExt(assignment: OpNodes.Assignment): AssignmentMethods = new AssignmentMethods(assignment) + implicit def toAssignmentExt(assignment: OpNodes.Assignment): AssignmentMethods = + new AssignmentMethods(assignment) + implicit def toAssignmentTrav(steps: Iterator[OpNodes.Assignment]): AssignmentTraversal = new AssignmentTraversal(steps) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala index 5ddf05c29036..c012b035b506 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala @@ -1,11 +1,10 @@ package io.shiftleft.semanticcpg.language.operatorextension +import io.shiftleft.codepropertygraph.generated.help.{Doc, TraversalSource} import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help.{Doc, TraversalSource} -/** Steps that allow traversing from `cpg` to operators. - */ +/** Steps that allow traversing from `cpg` to operators. */ @TraversalSource class NodeTypeStarters(cpg: Cpg) { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala index d5c61829abb7..ab9ccd10a6af 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala @@ -1,9 +1,11 @@ package io.shiftleft.semanticcpg.language.operatorextension +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.nodes.AstNode import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc +@Traversal(elementType = classOf[AstNode]) class OpAstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal { @Doc(info = "Any assignments that this node is a part of (traverse up)") diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala index 3a17608b98e3..4e8b700c8ae8 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala @@ -1,9 +1,11 @@ package io.shiftleft.semanticcpg.language.operatorextension +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.nodes.Expression import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc +@Traversal(elementType = classOf[Expression]) class TargetTraversal(val traversal: Iterator[Expression]) extends AnyVal { @Doc( diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala index 8675591639b8..578c429f54ed 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala @@ -22,7 +22,7 @@ class ArrayAccessMethods(val arrayAccess: OpNodes.ArrayAccess) extends AnyVal { } def simpleName: Iterator[String] = { - arrayAccess.array match { + array match { case id: Identifier => Iterator.single(id.name) case _ => Iterator.empty } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala index 4faa072246f2..d470dbe92a12 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala @@ -1,7 +1,7 @@ package io.shiftleft.semanticcpg.language.operatorextension.nodemethods import io.shiftleft.codepropertygraph.generated.nodes.Expression -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes class AssignmentMethods(val assignment: OpNodes.Assignment) extends AnyVal { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala index 1a301475a3eb..dc633f06599b 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala @@ -2,7 +2,7 @@ package io.shiftleft.semanticcpg.language.operatorextension.nodemethods import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Call, Expression} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.{OpNodes, allArrayAccessTypes} class TargetMethods(val expr: Expression) extends AnyVal { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala index 1917298f0a5f..888cd4d5f9c7 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala @@ -1,8 +1,10 @@ package io.shiftleft.semanticcpg +import flatgraph.help.DocSearchPackages +import io.shiftleft.codepropertygraph.generated import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.traversal.NodeTraversalImplicits +import io.shiftleft.semanticcpg.language.SarifExtension import io.shiftleft.semanticcpg.language.bindingextension.{ MethodTraversal as BindingMethodTraversal, TypeDeclTraversal as BindingTypeDeclTraversal @@ -10,16 +12,11 @@ import io.shiftleft.semanticcpg.language.bindingextension.{ import io.shiftleft.semanticcpg.language.callgraphextension.{CallTraversal, MethodTraversal} import io.shiftleft.semanticcpg.language.dotextension.{AstNodeDot, CfgNodeDot, InterproceduralNodeDot} import io.shiftleft.semanticcpg.language.nodemethods.* -import io.shiftleft.semanticcpg.language.types.expressions.generalizations.{ - AstNodeTraversal, - CfgNodeTraversal, - DeclarationTraversal, - ExpressionTraversal -} +import io.shiftleft.semanticcpg.language.types.expressions.generalizations.* import io.shiftleft.semanticcpg.language.types.expressions.{CallTraversal as OriginalCall, *} import io.shiftleft.semanticcpg.language.types.propertyaccessors.* import io.shiftleft.semanticcpg.language.types.structure.{MethodTraversal as OriginalMethod, *} -import overflowdb.NodeOrDetachedNode +import io.shiftleft.semanticcpg.language.types.structure.* /** Language for traversing the code property graph * @@ -27,20 +24,19 @@ import overflowdb.NodeOrDetachedNode * `steps` package, e.g. `Steps` */ package object language - extends operatorextension.Implicits + extends generated.language + with operatorextension.Implicits with modulevariable.Implicits with importresolver.Implicits - with LowPrioImplicits - with NodeTraversalImplicits { + with LowPrioImplicits { // Implicit conversions from generated node types. We use these to add methods // to generated node types. + implicit def cfgNodeToAstNode(node: CfgNode): AstNodeMethods = new AstNodeMethods(node) + implicit def toExtendedNode(node: AbstractNode): NodeMethods = new NodeMethods(node) + implicit def toExtendedStoredNode(node: StoredNode): StoredNodeMethods = new StoredNodeMethods(node) + implicit def toAstNodeMethods(node: AstNode): AstNodeMethods = new AstNodeMethods(node) + implicit def toExpressionMethods(node: Expression): ExpressionMethods = new ExpressionMethods(node) - implicit def cfgNodeToAsNode(node: CfgNode): AstNodeMethods = new AstNodeMethods(node) - implicit def toExtendedNode(node: NodeOrDetachedNode): NodeMethods = new NodeMethods(node) - implicit def toExtendedStoredNode(node: StoredNode): StoredNodeMethods = new StoredNodeMethods(node) - implicit def toAstNodeMethods(node: AstNode): AstNodeMethods = new AstNodeMethods(node) - implicit def toCfgNodeMethods(node: CfgNode): CfgNodeMethods = new CfgNodeMethods(node) - implicit def toExpressionMethods(node: Expression): ExpressionMethods = new ExpressionMethods(node) implicit def toMethodMethods(node: Method): MethodMethods = new MethodMethods(node) implicit def toMethodReturnMethods(node: MethodReturn): MethodReturnMethods = new MethodReturnMethods(node) implicit def toCallMethods(node: Call): CallMethods = new CallMethods(node) @@ -68,8 +64,7 @@ package object language implicit def iterOnceToTypeDeclTrav[A <: TypeDecl](a: IterableOnce[A]): TypeDeclTraversal = new TypeDeclTraversal(a.iterator) - implicit def iterOnceToOriginalCallTrav[A <: Call](a: IterableOnce[A]): OriginalCall = - new OriginalCall(a.iterator) + implicit def iterOnceToOriginalCallTrav(traversal: IterableOnce[Call]): OriginalCall = new OriginalCall(traversal) implicit def singleToControlStructureTrav[A <: ControlStructure](a: A): ControlStructureTraversal = new ControlStructureTraversal(Iterator.single(a)) @@ -110,8 +105,6 @@ package object language implicit def iterOnceToMethodParameterInTrav[A <: MethodParameterIn](a: IterableOnce[A]): MethodParameterTraversal = new MethodParameterTraversal(a.iterator) - implicit def singleToMethodParameterOutTrav[A <: MethodParameterOut](a: A): MethodParameterOutTraversal = - new MethodParameterOutTraversal(Iterator.single(a)) implicit def iterOnceToMethodParameterOutTrav[A <: MethodParameterOut]( a: IterableOnce[A] ): MethodParameterOutTraversal = @@ -163,11 +156,6 @@ package object language implicit def iterOnceToBindingTypeDeclTrav[A <: TypeDecl](a: IterableOnce[A]): BindingTypeDeclTraversal = new BindingTypeDeclTraversal(a.iterator) - implicit def singleToAstNodeDot[A <: AstNode](a: A): AstNodeDot[A] = - new AstNodeDot(Iterator.single(a)) - implicit def iterOnceToAstNodeDot[A <: AstNode](a: IterableOnce[A]): AstNodeDot[A] = - new AstNodeDot(a.iterator) - implicit def singleToCfgNodeDot[A <: Method](a: A): CfgNodeDot = new CfgNodeDot(Iterator.single(a)) implicit def iterOnceToCfgNodeDot[A <: Method](a: IterableOnce[A]): CfgNodeDot = @@ -268,11 +256,32 @@ package object language implicit def toExpression[A <: Expression](a: IterableOnce[A]): ExpressionTraversal[A] = new ExpressionTraversal[A](a.iterator) + + object NonStandardImplicits { + + // note: this causes problems because MethodParameterOut has an `index` property and the `MethodParameterOutTraversal` defines an `index` step... + implicit def singleToMethodParameterOutTrav[A <: MethodParameterOut](a: A): MethodParameterOutTraversal = + new MethodParameterOutTraversal(Iterator.single(a)) + + } + + implicit def singleToSarifTraversal[A <: Finding](a: A): SarifExtension = new SarifExtension(Iterator.single(a)) + implicit def iterOnceToSarifTraversal[A <: Finding](a: IterableOnce[A]): SarifExtension = new SarifExtension(a) } -trait LowPrioImplicits extends overflowdb.traversal.Implicits { - implicit def singleToCfgNodeTraversal[A <: CfgNode](a: A): CfgNodeTraversal[A] = - new CfgNodeTraversal[A](Iterator.single(a)) +trait LowPrioImplicits { + implicit val docSearchPackages: DocSearchPackages = + Cpg.defaultDocSearchPackage + .withAdditionalPackage("io.joern") + .withAdditionalPackage("io.shiftleft") + + implicit def singleToAstNodeDot[A <: AstNode](a: A): AstNodeDot[A] = + new AstNodeDot(Iterator.single(a)) + implicit def iterOnceToAstNodeDot[A <: AstNode](a: IterableOnce[A]): AstNodeDot[A] = + new AstNodeDot(a.iterator) + + implicit def toCfgNodeMethods(node: CfgNode): CfgNodeMethods = new CfgNodeMethods(node) + implicit def iterOnceToCfgNodeTraversal[A <: CfgNode](a: IterableOnce[A]): CfgNodeTraversal[A] = new CfgNodeTraversal[A](a.iterator) @@ -285,4 +294,5 @@ trait LowPrioImplicits extends overflowdb.traversal.Implicits { new DeclarationTraversal[A](Iterator.single(a)) implicit def iterOnceToDeclarationNodeTraversal[A <: Declaration](a: IterableOnce[A]): DeclarationTraversal[A] = new DeclarationTraversal[A](a.iterator) + } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala index 33e47f74f375..5e87f6cebcc2 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala @@ -5,19 +5,20 @@ import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment import io.shiftleft.semanticcpg.language.operatorextension.allAssignmentTypes -/** A call site - */ +/** A call site. */ class CallTraversal(val traversal: Iterator[Call]) extends AnyVal { - /** Only statically dispatched calls - */ + /** Only statically dispatched calls */ def isStatic: Iterator[Call] = - traversal.dispatchType("STATIC_DISPATCH") + traversal.filter(_.isStatic) - /** Only dynamically dispatched calls - */ + /** Only dynamically dispatched calls */ def isDynamic: Iterator[Call] = - traversal.dispatchType("DYNAMIC_DISPATCH") + traversal.filter(_.isDynamic) + + /** Only dispatched calls inline */ + def isInline: Iterator[Call] = + traversal.filter(_.isInline) /** Only assignment calls */ diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala index f117e19b91e9..831e6358b5bf 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala @@ -1,15 +1,17 @@ package io.shiftleft.semanticcpg.language.types.expressions +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, ControlStructure, Expression} -import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Properties} +import io.shiftleft.codepropertygraph.generated.ControlStructureTypes import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc object ControlStructureTraversal { val secondChildIndex = 2 val thirdChildIndex = 3 } +@Traversal(elementType = classOf[ControlStructure]) class ControlStructureTraversal(val traversal: Iterator[ControlStructure]) extends AnyVal { import ControlStructureTraversal.* @@ -23,11 +25,11 @@ class ControlStructureTraversal(val traversal: Iterator[ControlStructure]) exten @Doc(info = "Sub tree taken when condition evaluates to true") def whenTrue: Iterator[AstNode] = - traversal.out.has(Properties.Order, secondChildIndex: Int).cast[AstNode] + traversal.out.collectAll[AstNode].order(secondChildIndex) @Doc(info = "Sub tree taken when condition evaluates to false") def whenFalse: Iterator[AstNode] = - traversal.out.has(Properties.Order, thirdChildIndex).cast[AstNode] + traversal.out.collectAll[AstNode].order(thirdChildIndex) @Doc(info = "Only `Try` control structures") def isTry: Iterator[ControlStructure] = diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala index 8b11e5635c21..b9c6ea9f7e69 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala @@ -1,7 +1,7 @@ package io.shiftleft.semanticcpg.language.types.expressions import io.shiftleft.codepropertygraph.generated.nodes.{Declaration, Identifier} -import io.shiftleft.semanticcpg.language.toTraversalSugarExt +import io.shiftleft.semanticcpg.language.* /** An identifier, e.g., an instance of a local variable, or a temporary variable */ diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala index 33bfa563f290..7a66de691219 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala @@ -1,20 +1,18 @@ package io.shiftleft.semanticcpg.language.types.expressions.generalizations +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc -@help.Traversal(elementType = classOf[AstNode]) +@Traversal(elementType = classOf[AstNode]) class AstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal { /** Nodes of the AST rooted in this node, including the node itself. */ @Doc(info = "All nodes of the abstract syntax tree") - def ast: Iterator[AstNode] = { - traversal.repeat(_.out(EdgeTypes.AST))(_.emit).cast[AstNode] - } + def ast: Iterator[AstNode] = + traversal.repeat(_._astOut)(_.emit).cast[AstNode] /** All nodes of the abstract syntax tree rooted in this node, which match `predicate`. Equivalent of `match` in the * original CPG paper. @@ -38,7 +36,7 @@ class AstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal /** Nodes of the AST rooted in this node, minus the node itself */ def astMinusRoot: Iterator[AstNode] = - traversal.repeat(_.out(EdgeTypes.AST))(_.emitAllButFirst).cast[AstNode] + traversal.repeat(_._astOut)(_.emitAllButFirst).cast[AstNode] /** Direct children of node in the AST. Siblings are ordered by their `order` fields */ @@ -48,7 +46,7 @@ class AstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal /** Parent AST node */ def astParent: Iterator[AstNode] = - traversal.in(EdgeTypes.AST).cast[AstNode] + traversal._astIn.cast[AstNode] /** Siblings of this node in the AST, ordered by their `order` fields */ @@ -58,7 +56,7 @@ class AstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal /** Traverses up the AST and returns the first block node. */ def parentBlock: Iterator[Block] = - traversal.repeat(_.in(EdgeTypes.AST))(_.emit.until(_.hasLabel(NodeTypes.BLOCK))).collectAll[Block] + traversal.repeat(_._astIn)(_.emit.until(_.hasLabel(Block.Label))).collectAll[Block] /** Nodes of the AST obtained by expanding AST edges backwards until the method root is reached */ @@ -72,24 +70,26 @@ class AstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal /** Nodes of the AST obtained by expanding AST edges backwards until `root` or the method root is reached */ - def inAst(root: AstNode): Iterator[AstNode] = + def inAst(root: AstNode): Iterator[AstNode] = { traversal - .repeat(_.in(EdgeTypes.AST))( + .repeat(_._astIn)( _.emit - .until(_.or(_.hasLabel(NodeTypes.METHOD), _.filter(n => root != null && root == n))) + .until(_.or(_.hasLabel(Method.Label), _.filter(n => root != null && root == n))) ) .cast[AstNode] + } /** Nodes of the AST obtained by expanding AST edges backwards until `root` or the method root is reached, minus this * node */ - def inAstMinusLeaf(root: AstNode): Iterator[AstNode] = + def inAstMinusLeaf(root: AstNode): Iterator[AstNode] = { traversal - .repeat(_.in(EdgeTypes.AST))( + .repeat(_._astIn)( _.emitAllButFirst - .until(_.or(_.hasLabel(NodeTypes.METHOD), _.filter(n => root != null && root == n))) + .until(_.or(_.hasLabel(Method.Label), _.filter(n => root != null && root == n))) ) .cast[AstNode] + } /** Traverse only to those AST nodes that are also control flow graph nodes */ @@ -208,10 +208,11 @@ class AstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal def isTypeDecl: Iterator[TypeDecl] = traversal.collectAll[TypeDecl] - def walkAstUntilReaching(labels: List[String]): Iterator[StoredNode] = + def walkAstUntilReaching(labels: List[String]): Iterator[StoredNode] = { traversal - .repeat(_.out(EdgeTypes.AST))(_.emitAllButFirst.until(_.hasLabel(labels*))) + .repeat(_._astOut)(_.emitAllButFirst.until(_.hasLabel(labels*))) .dedup .cast[StoredNode] + } } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala index 3a3d22f7840f..99843e52b96f 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala @@ -1,11 +1,12 @@ package io.shiftleft.semanticcpg.language.types.expressions.generalizations -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ -import overflowdb.traversal.help -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.neighboraccessors.Lang.* +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.codepropertygraph.generated.help.Doc -@help.Traversal(elementType = classOf[CfgNode]) +@Traversal(elementType = classOf[CfgNode]) class CfgNodeTraversal[A <: CfgNode](val traversal: Iterator[A]) extends AnyVal { /** Textual representation of CFG node @@ -21,7 +22,6 @@ class CfgNodeTraversal[A <: CfgNode](val traversal: Iterator[A]) extends AnyVal /** Traverse to next expression in CFG. */ - @Doc(info = "Nodes directly reachable via outgoing CFG edges") def cfgNext: Iterator[CfgNode] = traversal._cfgOut diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala index f1075dad19c3..be7e4c5566cc 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala @@ -1,28 +1,27 @@ package io.shiftleft.semanticcpg.language.types.expressions.generalizations +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help -/** A declaration, such as a local or parameter. - */ -@help.Traversal(elementType = classOf[Declaration]) +/** A declaration, such as a local or parameter. */ +@Traversal(elementType = classOf[Declaration]) class DeclarationTraversal[NodeType <: Declaration](val traversal: Iterator[NodeType]) extends AnyVal { - /** The closure binding node referenced by this declaration - */ + /** The closure binding node referenced by this declaration */ + @Doc(info = "The closure binding node referenced by this declaration") def closureBinding: Iterator[ClosureBinding] = traversal.flatMap(_._refIn).collectAll[ClosureBinding] - /** Methods that capture this declaration - */ + /** Methods that capture this declaration */ + @Doc(info = "Methods that capture this declaration") def capturedByMethodRef: Iterator[MethodRef] = closureBinding.flatMap(_._captureIn).collectAll[MethodRef] - /** Types that capture this declaration - */ + /** Types that capture this declaration */ + @Doc(info = "Types that capture this declaration") def capturedByTypeRef: Iterator[TypeRef] = closureBinding.flatMap(_._captureIn).collectAll[TypeRef] - /** The parent method. - */ + /** The parent method. */ + @Doc(info = "The parent method.") def method: Iterator[Method] = traversal.flatMap { case x: Local => x.method case x: MethodParameterIn => x.method diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala index 741f69202624..c3e22ff0dfc7 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala @@ -59,8 +59,8 @@ class ExpressionTraversal[NodeType <: Expression](val traversal: Iterator[NodeTy */ def method: Iterator[Method] = traversal._containsIn - .flatMap { - case x: Method => x.start + .map { + case x: Method => x case x: TypeDecl => x.astParent } .collectAll[Method] diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala index 679b220aafaa..76b7d1672f61 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala @@ -1,10 +1,8 @@ package io.shiftleft.semanticcpg.language.types.propertyaccessors import io.shiftleft.codepropertygraph.generated.ModifierTypes -import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Modifier} -import io.shiftleft.codepropertygraph.generated.traversal.toModifierTraversalExtGen +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -import overflowdb.* class ModifierAccessors[A <: AstNode](val traversal: Iterator[A]) extends AnyVal { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala index a597e668bf1f..104ebab2ea2d 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala @@ -1,34 +1,34 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated.nodes -import overflowdb.traversal._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* /** An (Java-) annotation, e.g., @Test. */ -class AnnotationTraversal(val traversal: Iterator[nodes.Annotation]) extends AnyVal { +class AnnotationTraversal(val traversal: Iterator[Annotation]) extends AnyVal { /** Traverse to parameter assignments */ - def parameterAssign: Iterator[nodes.AnnotationParameterAssign] = + def parameterAssign: Iterator[AnnotationParameterAssign] = traversal.flatMap(_._annotationParameterAssignViaAstOut) /** Traverse to methods annotated with this annotation. */ - def method: Iterator[nodes.Method] = + def method: Iterator[Method] = traversal.flatMap(_._methodViaAstIn) /** Traverse to type declarations annotated by this annotation */ - def typeDecl: Iterator[nodes.TypeDecl] = + def typeDecl: Iterator[TypeDecl] = traversal.flatMap(_._typeDeclViaAstIn) /** Traverse to member annotated by this annotation */ - def member: Iterator[nodes.Member] = + def member: Iterator[Member] = traversal.flatMap(_._memberViaAstIn) /** Traverse to parameter annotated by this annotation */ - def parameter: Iterator[nodes.MethodParameterIn] = + def parameter: Iterator[MethodParameterIn] = traversal.flatMap(_._methodParameterInViaAstIn) } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala index 36309ff44325..93dbb4f4ce87 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala @@ -1,9 +1,9 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated.nodes.Import -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -class DependencyTraversal(val traversal: Iterator[nodes.Dependency]) extends AnyVal { - def imports: Iterator[Import] = traversal.in(EdgeTypes.IMPORTS).cast[Import] +class DependencyTraversal(val traversal: Iterator[Dependency]) extends AnyVal { + def imports: Iterator[Import] = + traversal.importsIn } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala index c9aec5674e6f..f520f09db339 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala @@ -1,6 +1,6 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* /** A compilation unit diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala index e5131658039a..e46ff2b324e6 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala @@ -1,6 +1,6 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Import, NamespaceBlock} +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class ImportTraversal(val traversal: Iterator[Import]) extends AnyVal { diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala index 26c73041412c..5319bfc24365 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala @@ -1,6 +1,6 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} import io.shiftleft.semanticcpg.language.* @@ -13,8 +13,8 @@ class LocalTraversal(val traversal: Iterator[Local]) extends AnyVal { def method: Iterator[Method] = { // TODO The following line of code is here for backwards compatibility. // Use the lower commented out line once not required anymore. - traversal.repeat(_.in(EdgeTypes.AST))(_.until(_.hasLabel(NodeTypes.METHOD))).cast[Method] - // definingBlock.method + traversal.repeat(_._astIn)(_.until(_.hasLabel(Method.Label))).cast[Method] +// definingBlock.method } } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala index d83a7062a256..5f45484923ea 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala @@ -1,7 +1,6 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated._ -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Member} +import io.shiftleft.codepropertygraph.generated.nodes.{Annotation, Call, Member} import io.shiftleft.semanticcpg.language.* /** A member variable of a class/type. @@ -10,7 +9,7 @@ class MemberTraversal(val traversal: Iterator[Member]) extends AnyVal { /** Traverse to annotations of member */ - def annotation: Iterator[nodes.Annotation] = + def annotation: Iterator[Annotation] = traversal.flatMap(_._annotationViaAstOut) /** Places where diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala index abb9778a1fd6..3ef123059d5f 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala @@ -3,11 +3,12 @@ package io.shiftleft.semanticcpg.language.types.structure import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -import scala.jdk.CollectionConverters.* - class MethodParameterOutTraversal(val traversal: Iterator[MethodParameterOut]) extends AnyVal { - def paramIn: Iterator[MethodParameterIn] = traversal.flatMap(_.parameterLinkIn.headOption) + def paramIn: Iterator[MethodParameterIn] = { + // TODO define a named step in schema + traversal.flatMap(_.parameterLinkIn.collectAll[MethodParameterIn]) + } /* method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ def index(num: Int): Iterator[MethodParameterOut] = @@ -27,9 +28,10 @@ class MethodParameterOutTraversal(val traversal: Iterator[MethodParameterOut]) e for { paramOut <- traversal method = paramOut.method - call <- method.callIn - arg <- call.argumentOut.collectAll[Expression] - if paramOut.parameterLinkIn.index.headOption.contains(arg.argumentIndex) + call <- method._callIn + arg <- call._argumentOut.collectAll[Expression] + // TODO define 'parameterLinkIn' as named step in schema + if paramOut.parameterLinkIn.collectAll[MethodParameterIn].index.headOption.contains(arg.argumentIndex) } yield arg } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala index 2f2af5ea8302..8d5a74d56ba3 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala @@ -1,33 +1,32 @@ package io.shiftleft.semanticcpg.language.types.structure +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help import scala.jdk.CollectionConverters.* -/** Formal method input parameter - */ -@help.Traversal(elementType = classOf[MethodParameterIn]) +/** Formal method input parameter */ +@Traversal(elementType = classOf[MethodParameterIn]) class MethodParameterTraversal(val traversal: Iterator[MethodParameterIn]) extends AnyVal { - /** Traverse to parameter annotations - */ + /** Traverse to parameter annotations */ + @Doc(info = "Traverse to parameter annotations") def annotation: Iterator[Annotation] = traversal.flatMap(_._annotationViaAstOut) - /** Traverse to all parameters with index greater or equal than `num` - */ + /** Traverse to all parameters with index greater or equal than `num` */ + @Doc(info = "Traverse to all parameters with index greater or equal than `num`") def indexFrom(num: Int): Iterator[MethodParameterIn] = traversal.filter(_.index >= num) - /** Traverse to all parameters with index smaller or equal than `num` - */ + /** Traverse to all parameters with index smaller or equal than `num` */ + @Doc(info = "Traverse to all parameters with index smaller or equal than `num`") def indexTo(num: Int): Iterator[MethodParameterIn] = traversal.filter(_.index <= num) - /** Traverse to arguments (actual parameters) associated with this formal parameter - */ + /** Traverse to arguments (actual parameters) associated with this formal parameter */ + @Doc(info = "Traverse to arguments (actual parameters) associated with this formal parameter") def argument(implicit callResolver: ICallResolver): Iterator[Expression] = for { paramIn <- traversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala index a96b5fb9baf7..fb1dc0660ea1 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala @@ -1,16 +1,16 @@ package io.shiftleft.semanticcpg.language.types.structure +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -import overflowdb.traversal.help -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc -@help.Traversal(elementType = classOf[MethodReturn]) +@Traversal(elementType = classOf[MethodReturn]) class MethodReturnTraversal(val traversal: Iterator[MethodReturn]) extends AnyVal { @Doc(info = "traverse to parent method") def method: Iterator[Method] = - traversal.flatMap(_._methodViaAstIn) + traversal._methodViaAstIn def returnUser(implicit callResolver: ICallResolver): Iterator[Call] = traversal.flatMap(_.returnUser) @@ -19,11 +19,11 @@ class MethodReturnTraversal(val traversal: Iterator[MethodReturn]) extends AnyVa */ @Doc(info = "traverse to last expressions in CFG (can be multiple)") def cfgLast: Iterator[CfgNode] = - traversal.flatMap(_.cfgIn) + traversal.cfgIn /** Traverse to return type */ @Doc(info = "traverse to return type") def typ: Iterator[Type] = - traversal.flatMap(_.evalTypeOut) + traversal.evalTypeOut } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala index ef31a200c906..e0ac06d6e979 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala @@ -1,15 +1,14 @@ package io.shiftleft.semanticcpg.language.types.structure +import io.shiftleft.codepropertygraph.generated.help.{Doc, Traversal} import io.shiftleft.codepropertygraph.generated.* import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -import overflowdb.* -import overflowdb.traversal.help -import overflowdb.traversal.help.Doc +import io.shiftleft.codepropertygraph.generated.help.Doc /** A method, function, or procedure */ -@help.Traversal(elementType = classOf[Method]) +@Traversal(elementType = classOf[Method]) class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal { /** Traverse to annotations of method @@ -129,11 +128,7 @@ class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal { @Doc(info = "Top level expressions (\"Statements\")") def topLevelExpressions: Iterator[Expression] = - traversal._astOut - .collectAll[Block] - ._astOut - .not(_.collectAll[Local]) - .cast[Expression] + traversal.flatMap(_.topLevelExpressions) @Doc(info = "Control flow graph nodes") def cfgNode: Iterator[CfgNode] = @@ -164,7 +159,8 @@ class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal { // some language frontends don't have a TYPE_DECL for a METHOD case Some(namespaceBlock: NamespaceBlock) => namespaceBlock.start // other language frontends always embed their method in a TYPE_DECL - case _ => m.definingTypeDecl.namespaceBlock + case _ => + m.definingTypeDecl.iterator.namespaceBlock } } } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala index 7f95dee3285a..c6aaecf5b1a3 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala @@ -5,16 +5,17 @@ import io.shiftleft.semanticcpg.language.* class NamespaceBlockTraversal(val traversal: Iterator[NamespaceBlock]) extends AnyVal { - /** Namespaces for namespace blocks. + /** Namespaces for namespace blocks. TODO define a name in the schema */ def namespace: Iterator[Namespace] = - traversal.flatMap(_.refOut) + traversal.flatMap(_._namespaceViaRefOut) - /** The type declarations defined in this namespace + /** The type declarations defined in this namespace TODO define a name in the schema */ def typeDecl: Iterator[TypeDecl] = traversal.flatMap(_._typeDeclViaAstOut) + // TODO define a name in the schema def method: Iterator[Method] = traversal.flatMap(_._methodViaAstOut) } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala index c636047e2dff..3509c5f6a058 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala @@ -10,12 +10,12 @@ class NamespaceTraversal(val traversal: Iterator[Namespace]) extends AnyVal { /** The type declarations defined in this namespace */ def typeDecl: Iterator[TypeDecl] = - traversal.flatMap(_.refIn).flatMap(_._typeDeclViaAstOut) + traversal.refIn.astOut.collectAll[TypeDecl] /** Methods defined in this namespace */ def method: Iterator[Method] = - traversal.flatMap(_.refIn).flatMap(_._methodViaAstOut) + traversal.refIn.astOut.collectAll[Method] /** External namespaces - any namespaces which contain one or more external type. */ diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala index 0df69620c7a0..ea97c42e32cb 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala @@ -1,6 +1,5 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated.nodes import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* @@ -11,13 +10,13 @@ class TypeDeclTraversal(val traversal: Iterator[TypeDecl]) extends AnyVal { /** Annotations of the type declaration */ - def annotation: Iterator[nodes.Annotation] = + def annotation: Iterator[Annotation] = traversal.flatMap(_._annotationViaAstOut) /** Types referencing to this type declaration. */ def referencingType: Iterator[Type] = - traversal.flatMap(_.refIn) + traversal.refIn /** Namespace in which this type declaration is defined */ @@ -57,7 +56,7 @@ class TypeDeclTraversal(val traversal: Iterator[TypeDecl]) extends AnyVal { /** Direct and transitive base type declaration. */ def derivedTypeDeclTransitive: Iterator[TypeDecl] = - traversal.repeat(_.derivedTypeDecl)(_.emitAllButFirst) + traversal.repeat(_.derivedTypeDecl)(_.emitAllButFirst.dedup) /** Direct base type declaration. */ @@ -67,7 +66,7 @@ class TypeDeclTraversal(val traversal: Iterator[TypeDecl]) extends AnyVal { /** Direct and transitive base type declaration. */ def baseTypeDeclTransitive: Iterator[TypeDecl] = - traversal.repeat(_.baseTypeDecl)(_.emitAllButFirst) + traversal.repeat(_.baseTypeDecl)(_.emitAllButFirst.dedup) /** Traverse to alias type declarations. */ @@ -105,7 +104,7 @@ class TypeDeclTraversal(val traversal: Iterator[TypeDecl]) extends AnyVal { /** Direct and transitive alias type declarations. */ def aliasTypeDeclTransitive: Iterator[TypeDecl] = - traversal.repeat(_.aliasTypeDecl)(_.emitAllButFirst) + traversal.repeat(_.aliasTypeDecl)(_.emitAllButFirst.dedup) def content: Iterator[String] = { traversal.flatMap(contentOnSingle) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala index ea2bf4e0e536..de589cac3d91 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala @@ -1,6 +1,5 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated.nodes import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* @@ -8,7 +7,7 @@ class TypeTraversal(val traversal: Iterator[Type]) extends AnyVal { /** Annotations of the corresponding type declaration. */ - def annotation: Iterator[nodes.Annotation] = + def annotation: Iterator[Annotation] = traversal.referencedTypeDecl.annotation /** Namespaces in which the corresponding type declaration is defined. @@ -44,7 +43,7 @@ class TypeTraversal(val traversal: Iterator[Type]) extends AnyVal { /** Direct and transitive base types of the corresponding type declaration. */ def baseTypeTransitive: Iterator[Type] = - traversal.repeat(_.baseType)(_.emitAllButFirst) + traversal.repeat(_.baseType)(_.emitAllButFirst.dedup) /** Direct derived types. */ @@ -54,12 +53,12 @@ class TypeTraversal(val traversal: Iterator[Type]) extends AnyVal { /** Direct and transitive derived types. */ def derivedTypeTransitive: Iterator[Type] = - traversal.repeat(_.derivedType)(_.emitAllButFirst) + traversal.repeat(_.derivedType)(_.emitAllButFirst.dedup) /** Type declarations which derive from this type. */ def derivedTypeDecl: Iterator[TypeDecl] = - traversal.flatMap(_.inheritsFromIn) + traversal.inheritsFromIn /** Direct alias types. */ @@ -69,26 +68,27 @@ class TypeTraversal(val traversal: Iterator[Type]) extends AnyVal { /** Direct and transitive alias types. */ def aliasTypeTransitive: Iterator[Type] = - traversal.repeat(_.aliasType)(_.emitAllButFirst) + traversal.repeat(_.aliasType)(_.emitAllButFirst.dedup) def localOfType: Iterator[Local] = - traversal.flatMap(_._localViaEvalTypeIn) + traversal._localViaEvalTypeIn def memberOfType: Iterator[Member] = - traversal.flatMap(_.evalTypeIn).collectAll[Member] + traversal.evalTypeIn.collectAll[Member] @deprecated("Please use `parameterOfType`") def parameter: Iterator[MethodParameterIn] = parameterOfType def parameterOfType: Iterator[MethodParameterIn] = - traversal.flatMap(_.evalTypeIn).collectAll[MethodParameterIn] + traversal.evalTypeIn.collectAll[MethodParameterIn] def methodReturnOfType: Iterator[MethodReturn] = - traversal.flatMap(_.evalTypeIn).collectAll[MethodReturn] + traversal.evalTypeIn.collectAll[MethodReturn] def expressionOfType: Iterator[Expression] = expression + // TODO define in schema def expression: Iterator[Expression] = - traversal.flatMap(_.evalTypeIn).collectAll[Expression] + traversal.evalTypeIn.collectAll[Expression] } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala index 582eabb8eba4..64aa3f553d22 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala @@ -1,9 +1,6 @@ package io.shiftleft.semanticcpg.layers -import better.files.File -import io.shiftleft.SerializedCpg import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.passes.CpgPassBase import io.shiftleft.semanticcpg.Overlays import org.slf4j.{Logger, LoggerFactory} @@ -36,19 +33,6 @@ abstract class LayerCreator { } } - protected def initSerializedCpg(outputDir: Option[String], passName: String, index: Int = 0): SerializedCpg = { - outputDir match { - case Some(dir) => new SerializedCpg((File(dir) / s"${index}_$passName").path.toAbsolutePath.toString) - case None => new SerializedCpg() - } - } - - protected def runPass(pass: CpgPassBase, context: LayerCreatorContext, index: Int = 0): Unit = { - val serializedCpg = initSerializedCpg(context.outputDir, pass.name, index) - pass.createApplySerializeAndStore(serializedCpg) - serializedCpg.close() - } - def create(context: LayerCreatorContext): Unit } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/package.scala index 9b0f25175138..a427ab3d9678 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/package.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/package.scala @@ -1,6 +1,6 @@ package io.shiftleft -import overflowdb.traversal.help.Table.AvailableWidthProvider +import flatgraph.help.Table.AvailableWidthProvider /** Domain specific language for querying code property graphs * diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/SarifConfig.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/SarifConfig.scala new file mode 100644 index 000000000000..97e0cff59b31 --- /dev/null +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/SarifConfig.scala @@ -0,0 +1,48 @@ +package io.shiftleft.semanticcpg.sarif + +import io.shiftleft.semanticcpg.sarif.SarifConfig.SarifVersion +import io.shiftleft.semanticcpg.sarif.v2_1_0.JoernScanResultToSarifConverter +import org.json4s.Serializer + +import java.net.URI + +/** A configuration for tool-specific information and arguments on transforming how findings are to be converted to + * SARIF. + * + * @param toolName + * The name of the tool component. + * @param toolFullName + * The name of the tool component along with its version and any other useful identifying information, such as its + * locale. + * @param toolInformationUri + * The absolute URI at which information about this version of the tool component can be found. + * @param organization + * The organization or company that produced the tool component. + * @param semanticVersion + * The tool component version in the format specified by Semantic Versioning 2.0. + * @param sarifVersion + * The SARIF format version of the resulting log file. + * @param resultConverter + * A transformer class to map from Finding nodes to a SARIF `Result`. + * @param customSerializers + * Additional JSON serializers for any additional properties for [[io.shiftleft.semanticcpg.sarif.Sarif]] derived + * classes. + */ +case class SarifConfig( + toolName: String = "Joern", + toolFullName: Option[String] = Option("Joern - The Bug Hunter's Workbench"), + toolInformationUri: Option[URI] = Option(URI("https://joern.io")), + organization: Option[String] = Option("Joern.io"), + semanticVersion: Option[String] = None, + sarifVersion: SarifVersion = SarifVersion.V2_1_0, + resultConverter: ScanResultToSarifConverter = JoernScanResultToSarifConverter(), + customSerializers: List[Serializer[?]] = SarifSchema.serializers +) + +object SarifConfig { + + enum SarifVersion { + case V2_1_0 + } + +} diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/SarifSchema.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/SarifSchema.scala new file mode 100644 index 000000000000..7ad7b1a54f2b --- /dev/null +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/SarifSchema.scala @@ -0,0 +1,479 @@ +package io.shiftleft.semanticcpg.sarif + +import org.json4s.{CustomSerializer, Extraction, Serializer} +import org.slf4j.LoggerFactory + +import java.net.URI + +object SarifSchema { + + private val logger = LoggerFactory.getLogger(getClass) + + /** Provides a basic Sarif trait under which possibly multiple defined schemata would be defined. + */ + sealed trait Sarif { + + /** @return + * The SARIF format version of this log file. + */ + def version: String + + /** @return + * The URI of the JSON schema corresponding to the version. + */ + def schema: String + + /** @return + * The set of runs contained in this log file. + */ + def runs: List[Run] + } + + case class Sarif2_1_0(runs: List[Run]) extends Sarif { + def version: String = "2.1.0" + + def schema: String = "https://docs.oasis-open.org/sarif/sarif/v2.1.0/errata01/os/schemas/sarif-schema-2.1.0.json" + } + + // Minimal properties we want to use across versions: + + /** Represents the contents of an artifact. + */ + trait ArtifactContent private[sarif] { + + /** @return + * UTF-8-encoded content from a text artifact. + */ + def text: String + } + + /** Specifies the location of an artifact. + */ + trait ArtifactLocation private[sarif] { + + /** @return + * A string containing a valid relative or absolute URI. + */ + def uri: Option[URI] + + /** @return + * A string which indirectly specifies the absolute URI with respect to which a relative URI in the "uri" + * property is interpreted. + */ + def uriBaseId: Option[String] + } + + /** A set of threadFlows which together describe a pattern of code execution relevant to detecting a result. + */ + trait CodeFlow private[sarif] { + + /** @return + * A message relevant to the code flow. + */ + def message: Option[Message] + + /** @return + * An array of one or more unique threadFlow objects, each of which describes the progress of a program through a + * thread of execution. + */ + def threadFlows: List[ThreadFlow] + } + + /** A location within a programming artifact. + */ + trait Location private[sarif] { + + /** @return + * Identifies the artifact and region. + */ + def physicalLocation: PhysicalLocation + } + + /** Encapsulates a message intended to be read by the end user. + */ + trait Message private[sarif] { + + /** @return + * A plain text message string. + */ + def text: String + + /** @return + * A Markdown message string. + */ + def markdown: Option[String] + } + + /** A physical location relevant to a result. Specifies a reference to a programming artifact together with a range of + * bytes or characters within that artifact. + */ + trait PhysicalLocation private[sarif] { + + /** @return + * The location of the artifact. + */ + def artifactLocation: ArtifactLocation + + /** @return + * Specifies a portion of the artifact. + */ + def region: Region + } + + /** A region within an artifact where a result was detected. + */ + trait Region private[sarif] { + + /** @return + * The line number of the first character in the region. + */ + def startLine: Option[Int] + + /** @return + * The column number of the first character in the region. + */ + def startColumn: Option[Int] + + /** @return + * The line number of the last character in the region. + */ + def endLine: Option[Int] + + /** @return + * The column number of the character following the end of the region. + */ + def endColumn: Option[Int] + + /** @return + * The portion of the artifact contents within the specified region. + */ + def snippet: Option[ArtifactContent] + + /** @return + * true if startLine is empty and larger than 0, as this is the main required property. + */ + def isEmpty: Boolean = startLine.forall(_ <= 0) + } + + /** Metadata that describes a specific report produced by the tool, as part of the analysis it provides or its runtime + * reporting. + */ + trait ReportingDescriptor private[sarif] { + + /** @return + * A stable, opaque identifier for the report. + */ + def id: String + + /** @return + * A report identifier that is understandable to an end user. + */ + def name: String + + /** @return + * A concise description of the report. Should be a single sentence that is understandable when visible space is + * limited to a single line of text. + */ + def shortDescription: Option[Message] + + /** @return + * A description of the report. Should, as far as possible, provide details sufficient to enable resolution of + * any problem indicated by the result. + */ + def fullDescription: Option[Message] + + /** @return + * A URI where the primary documentation for the report can be found. + */ + def helpUri: Option[URI] + + } + + /** A result produced by an analysis tool. + */ + trait Result private[sarif] { + + /** @return + * The stable, unique identifier of the rule, if any, to which this result is relevant. + */ + def ruleId: String + + /** @return + * A message that describes the result. The first sentence of the message only will be displayed when visible + * space is limited. + */ + def message: Message + + /** @return + * A value specifying the severity level of the result. + */ + def level: String + + /** @return + * The set of locations where the result was detected. Specify only one location unless the problem indicated by + * the result can only be corrected by making a change at every specified location. + */ + def locations: List[Location] + + /** @return + * A set of locations relevant to this result. + */ + def relatedLocations: List[Location] + + /** @return + * An array of 'codeFlow' objects relevant to the result. + */ + def codeFlows: List[CodeFlow] + + /** GitHub makes use of this property to track effectively the same finding across files between versions. + * @return + * A set of strings that contribute to the stable, unique identity of the result. + */ + def partialFingerprints: Map[String, String] + } + + /** Describes a single run of an analysis tool, and contains the reported output of that run. + */ + trait Run private[sarif] { + + /** @return + * Information about the tool or tool pipeline that generated the results in this run. A run can only contain + * results produced by a single tool or tool pipeline. A run can aggregate results from multiple log files, as + * long as context around the tool run (tool command-line arguments and the like) is identical for all aggregated + * files. + */ + def tool: Tool + + /** @return + * The set of results contained in an SARIF log. The results array can be omitted when a run is solely exporting + * rules metadata. It must be present (but may be empty) if a log file represents an actual scan. + */ + def results: List[Result] + + /** @return + * The artifact location specified by each uriBaseId symbol on the machine where the tool originally ran. + */ + def originalUriBaseIds: Map[String, ArtifactLocation] + } + + /** Describes a sequence of code locations that specify a path through a single thread of execution such as an + * operating system or fiber. + */ + trait ThreadFlow private[sarif] { + + /** @return + * A temporally ordered array of 'threadFlowLocation' objects, each of which describes a location visited by the + * tool while producing the result. + */ + def locations: List[ThreadFlowLocation] + } + + /** A location visited by an analysis tool while simulating or monitoring the execution of a program. + */ + trait ThreadFlowLocation private[sarif] { + + /** @return + * The code location. + */ + def location: Location + } + + /** The analysis tool that was run. + */ + trait Tool private[sarif] { + def driver: ToolComponent + } + + /** A component, such as a plug-in or the driver, of the analysis tool that was run. + */ + trait ToolComponent private[sarif] { + + /** @return + * The name of the tool component. + */ + def name: String + + /** @return + * The name of the tool component along with its version and any other useful identifying information, such as + * its locale. + */ + def fullName: Option[String] + + /** @return + * The organization or company that produced the tool component. + */ + def organization: Option[String] + + /** @return + * The tool component version in the format specified by Semantic Versioning 2.0. + */ + def semanticVersion: Option[String] + + /** @return + * The absolute URI at which information about this version of the tool component can be found. + */ + def informationUri: Option[URI] + + /** @return + * An array of reportingDescriptor objects relevant to the analysis performed by the tool component. + */ + def rules: List[ReportingDescriptor] + } + + /** A value specifying the severity level of the result. + */ + object Level { + val None = "none" + val Note = "note" + val Warning = "warning" + val Error = "error" + + def cvssToLevel(cvssScore: Double): String = { + cvssScore match { + case score if score < 0.0 || score > 10.0 => + logger.error(s"Score '$score' is not a valid CVSS score! Defaulting to 'warning' SARIF level.") + Warning + case score if score == 0.0 => None + case score if score <= 3.9 => Note + case score if score <= 6.9 => Warning + case score if score <= 10.0 => Error + } + } + + } + + val serializers: List[Serializer[?]] = List( + new CustomSerializer[SarifSchema.Sarif](implicit format => + ( + { case _ => + ??? + }, + { case sarif: SarifSchema.Sarif => + Extraction.decompose(Map("version" -> sarif.version, "$schema" -> sarif.schema, "runs" -> sarif.runs)) + } + ) + ), + new CustomSerializer[SarifSchema.ArtifactLocation](implicit format => + ( + { case _ => + ??? + }, + { case location: SarifSchema.ArtifactLocation => + val elementMap = Map.newBuilder[String, Any] + location.uri.foreach(x => elementMap.addOne("uri" -> x)) + elementMap.addOne("uriBaseId" -> location.uriBaseId) + Extraction.decompose(elementMap.result()) + } + ) + ), + new CustomSerializer[SarifSchema.CodeFlow](implicit format => + ( + { case _ => + ??? + }, + { case flow: SarifSchema.CodeFlow => + val elementMap = Map.newBuilder[String, Any] + flow.message.foreach(x => elementMap.addOne("message" -> x)) + elementMap.addOne("threadFlows" -> flow.threadFlows) + Extraction.decompose(elementMap.result()) + } + ) + ), + new CustomSerializer[SarifSchema.PhysicalLocation](implicit format => + ( + { case _ => + ??? + }, + { case location: SarifSchema.PhysicalLocation => + val elementMap = Map.newBuilder[String, Any] + elementMap.addOne("artifactLocation" -> location.artifactLocation) + if !location.region.isEmpty then elementMap.addOne("region" -> Extraction.decompose(location.region)) + Extraction.decompose(elementMap.result()) + } + ) + ), + new CustomSerializer[SarifSchema.Region](implicit format => + ( + { case _ => + ??? + }, + { case region: SarifSchema.Region => + val elementMap = Map.newBuilder[String, Any] + region.startLine.filterNot(x => x <= 0).foreach(x => elementMap.addOne("startLine" -> x)) + region.startColumn.filterNot(x => x <= 0).foreach(x => elementMap.addOne("startColumn" -> x)) + region.endLine.filterNot(x => x <= 0).foreach(x => elementMap.addOne("endLine" -> x)) + region.endColumn.filterNot(x => x <= 0).foreach(x => elementMap.addOne("endColumn" -> x)) + region.snippet.foreach(x => elementMap.addOne("snippet" -> x)) + Extraction.decompose(elementMap.result()) + } + ) + ), + new CustomSerializer[ReportingDescriptor](implicit format => + ( + { case _ => + ??? + }, + { case x: ReportingDescriptor => + val elementMap = Map.newBuilder[String, Any] + elementMap.addOne("id" -> x.id) + elementMap.addOne("name" -> x.name) + x.shortDescription.foreach(x => elementMap.addOne("shortDescription" -> x)) + x.fullDescription.foreach(x => elementMap.addOne("fullDescription" -> x)) + x.helpUri.foreach(x => elementMap.addOne("helpUri" -> x)) + Extraction.decompose(elementMap.result()) + } + ) + ), + new CustomSerializer[SarifSchema.Result](implicit format => + ( + { case _ => + ??? + }, + { case result: SarifSchema.Result => + val elementMap = Map.newBuilder[String, Any] + elementMap.addOne("ruleId" -> result.ruleId) + elementMap.addOne("message" -> result.message) + elementMap.addOne("level" -> result.level) + // Locations & related locations have no minimum, but do not allow duplicates + elementMap.addOne("locations" -> result.locations.distinct) + elementMap.addOne("relatedLocations" -> result.relatedLocations.distinct) + // codeFlows may be empty, but thread flows may not have empty arrays + elementMap.addOne("codeFlows" -> result.codeFlows.filterNot(_.threadFlows.isEmpty)) + + if result.partialFingerprints.nonEmpty then + elementMap.addOne("partialFingerprints" -> result.partialFingerprints) + + Extraction.decompose(elementMap.result()) + } + ) + ), + new CustomSerializer[ToolComponent](implicit format => + ( + { case _ => + ??? + }, + { case x: ToolComponent => + val elementMap = Map.newBuilder[String, Any] + elementMap.addOne("name" -> x.name) + x.fullName.foreach(x => elementMap.addOne("fullName" -> x)) + x.organization.foreach(x => elementMap.addOne("organization" -> x)) + x.semanticVersion.foreach(x => elementMap.addOne("semanticVersion" -> x)) + x.informationUri.foreach(x => elementMap.addOne("informationUri" -> x)) + elementMap.addOne("rules" -> x.rules) + Extraction.decompose(elementMap.result()) + } + ) + ), + new CustomSerializer[URI](implicit format => + ( + { case _ => + ??? + }, + { case uri: URI => + Extraction.decompose(uri.toString) + } + ) + ) + ) + +} diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/ScanResultToSarifConverter.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/ScanResultToSarifConverter.scala new file mode 100644 index 000000000000..d1db972e11a0 --- /dev/null +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/ScanResultToSarifConverter.scala @@ -0,0 +1,26 @@ +package io.shiftleft.semanticcpg.sarif + +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.sarif.SarifSchema.{ReportingDescriptor, Result} + +/** A component that converts a CPG finding to some version of SARIF. + */ +trait ScanResultToSarifConverter { + + /** Given a finding, will extract any rule data and create a SARIF ReportingDescriptor + * @param finding + * the finding to convert. + * @return + * a SARIF compliant reporting descriptor object if possible. + */ + def convertFindingToReportingDescriptor(finding: Finding): Option[ReportingDescriptor] + + /** Given a finding, will convert it to the SARIF specified result. + * @param finding + * the finding to convert. + * @return + * a SARIF compliant result object. + */ + def convertFindingToResult(finding: Finding): Result + +} diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/v2_1_0/JoernScanResultToSarifConverter.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/v2_1_0/JoernScanResultToSarifConverter.scala new file mode 100644 index 000000000000..bf83ec304249 --- /dev/null +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/v2_1_0/JoernScanResultToSarifConverter.scala @@ -0,0 +1,125 @@ +package io.shiftleft.semanticcpg.sarif.v2_1_0 + +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.{NodeExtensionFinder, *} +import io.shiftleft.semanticcpg.sarif.{SarifSchema, ScanResultToSarifConverter} + +import java.net.URI +import scala.util.Try + +/** Convert finding node to a SARIF v2.1.0 model. + */ +class JoernScanResultToSarifConverter extends ScanResultToSarifConverter { + + import JoernScanResultToSarifConverter.* + + override def convertFindingToReportingDescriptor(finding: Finding): Option[SarifSchema.ReportingDescriptor] = { + val description = createMessage(finding.description) + Option(Schema.ReportingDescriptor(id = finding.name, name = finding.title, fullDescription = Option(description))) + } + + override def convertFindingToResult(finding: Finding): SarifSchema.Result = { + val locations = finding.evidence.lastOption.map(nodeToLocation).toList + val relatedLocations = finding.evidence.headOption.map(nodeToLocation).toList + val codeFlows = evidenceToCodeFlow(finding) match { + case codeFlow if codeFlow.threadFlows.isEmpty => Nil + case codeFlow => codeFlow :: Nil + } + Schema.Result( + ruleId = finding.name, + message = Schema.Message(text = finding.title), + level = SarifSchema.Level.cvssToLevel(finding.score), + locations = locations, + relatedLocations = relatedLocations, + codeFlows = codeFlows + ) + } + + protected def evidenceToCodeFlow(finding: Finding): Schema.CodeFlow = { + val locations = finding.evidence.map(node => Schema.ThreadFlowLocation(location = nodeToLocation(node))).l + if (locations.isEmpty) { + Schema.CodeFlow(threadFlows = Nil) + } else { + Schema.CodeFlow(threadFlows = Schema.ThreadFlow(locations) :: Nil) + } + } + + protected def createMessage(text: String): Schema.Message = { + val plain = text.replace("`", "") // todo: use better markdown stripping + val markdown = Option(text).filterNot(_ == plain) // if these are equal, ignore + Schema.Message(text = plain, markdown = markdown) + } + + protected def nodeToLocation(node: StoredNode): Schema.Location = { + Schema.Location(physicalLocation = + Schema.PhysicalLocation( + artifactLocation = Schema.ArtifactLocation(uri = nodeToUri(node)), + region = nodeToRegion(node) + ) + ) + } + + protected def nodeToUri(node: StoredNode): Option[URI] = { + val fileNameOpt = node match { + case t: TypeDecl if !t.isExternal => Option(t.filename).filterNot(_ == "") + case m: Method if !m.isExternal => Option(m.filename).filterNot(_ == "") + case expr: Expression => expr.file.map(_.name).headOption + case _ => None + } + fileNameOpt.flatMap(x => Try(URI(x)).toOption) + } + + protected def nodeToRegion(node: StoredNode): Schema.Region = { + node match { + case t: TypeDecl => + Schema.Region( + startLine = t.lineNumber, + startColumn = t.columnNumber, + snippet = Option(Schema.ArtifactContent(t.code)) + ) + case m: Method => + Schema.Region( + startLine = m.lineNumber, + startColumn = m.columnNumber, + endLine = m.lineNumberEnd, + endColumn = m.columnNumberEnd, + snippet = Option(Schema.ArtifactContent(m.code)) + ) + case n: CfgNode => + Schema.Region( + startLine = n.lineNumber, + startColumn = n.columnNumber, + snippet = Option(Schema.ArtifactContent(n.code)) + ) + case _ => Schema.Region(None, None, None) + } + } + +} + +/** Due to module dependencies, the following code is lifted from `io.joern.console.scan`. + */ +object JoernScanResultToSarifConverter { + + private object FindingKeys { + val name = "name" + val title = "title" + val description = "description" + val score = "score" + } + + implicit class FindingExtension(val node: Finding) extends AnyRef { + + def name: String = getValue(FindingKeys.name) + + def title: String = getValue(FindingKeys.title) + + def description: String = getValue(FindingKeys.description).trim + + def score: Double = getValue(FindingKeys.score).toDoubleOption.getOrElse(-1d) + + protected def getValue(key: String, default: String = ""): String = + node.keyValuePairs.find(_.key == key).map(_.value).filterNot(_ == "-").getOrElse(default) + + } +} diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/v2_1_0/Schema.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/v2_1_0/Schema.scala new file mode 100644 index 000000000000..7f974dce37f9 --- /dev/null +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/sarif/v2_1_0/Schema.scala @@ -0,0 +1,77 @@ +package io.shiftleft.semanticcpg.sarif.v2_1_0 + +import io.shiftleft.semanticcpg.sarif.SarifSchema +import io.shiftleft.semanticcpg.sarif.SarifSchema.Location +import org.json4s.{CustomSerializer, Extraction} + +import java.net.URI + +object Schema { + + final case class ArtifactContent(text: String) extends SarifSchema.ArtifactContent + + /** Specifies the location of an artifact. + * + * @param uri + * A string containing a valid relative or absolute URI. + * @param uriBaseId + * A string which indirectly specifies the absolute URI with respect to which a relative URI in the "uri" property + * is interpreted. + */ + final case class ArtifactLocation(uri: Option[URI] = None, uriBaseId: Option[String] = Option("PROJECT_ROOT")) + extends SarifSchema.ArtifactLocation + + final case class CodeFlow(threadFlows: List[ThreadFlow], message: Option[Message] = None) extends SarifSchema.CodeFlow + + final case class Location(physicalLocation: PhysicalLocation) extends SarifSchema.Location + + final case class Message(text: String, markdown: Option[String] = None) extends SarifSchema.Message + + final case class PhysicalLocation(artifactLocation: ArtifactLocation, region: Region) + extends SarifSchema.PhysicalLocation + + final case class Region( + startLine: Option[Int], + startColumn: Option[Int] = None, + endLine: Option[Int] = None, + endColumn: Option[Int] = None, + snippet: Option[ArtifactContent] = None + ) extends SarifSchema.Region + + final case class ReportingDescriptor( + id: String, + name: String, + shortDescription: Option[Message] = None, + fullDescription: Option[Message] = None, + helpUri: Option[URI] = None + ) extends SarifSchema.ReportingDescriptor + + final case class Result( + ruleId: String, + message: Message, + level: String, + locations: List[Location], + relatedLocations: List[Location], + codeFlows: List[CodeFlow], + partialFingerprints: Map[String, String] = Map.empty + ) extends SarifSchema.Result + + final case class Run(tool: Tool, results: List[SarifSchema.Result], originalUriBaseIds: Map[String, ArtifactLocation]) + extends SarifSchema.Run + + final case class ThreadFlow(locations: List[ThreadFlowLocation]) extends SarifSchema.ThreadFlow + + final case class ThreadFlowLocation(location: Location) extends SarifSchema.ThreadFlowLocation + + final case class Tool(driver: ToolComponent) extends SarifSchema.Tool + + final case class ToolComponent( + name: String, + fullName: Option[String] = None, + organization: Option[String] = None, + semanticVersion: Option[String] = None, + informationUri: Option[URI] = None, + rules: List[SarifSchema.ReportingDescriptor] = Nil + ) extends SarifSchema.ToolComponent + +} diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala index 3c5a80dabddc..b523d8893fd3 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala @@ -1,52 +1,7 @@ package io.shiftleft.semanticcpg.testing import io.shiftleft.codepropertygraph.generated.nodes.StoredNode -import overflowdb.{Edge, Node, Property, PropertyKey} -import java.util - -/** mixin trait for test nodes */ trait DummyNodeImpl extends StoredNode { - // Members declared in overflowdb.Element - def graph(): overflowdb.Graph = ??? - def property[A](x$1: overflowdb.PropertyKey[A]): A = ??? - def property(x$1: String): Object = ??? - def propertyKeys(): java.util.Set[String] = ??? - def propertiesMap(): java.util.Map[String, Object] = ??? - def propertyOption(x$1: String): java.util.Optional[Object] = ??? - def propertyOption[A](x$1: overflowdb.PropertyKey[A]): java.util.Optional[A] = ??? - override def addEdgeImpl(label: String, inNode: Node, keyValues: Any*): Edge = ??? - override def addEdgeImpl(label: String, inNode: Node, keyValues: util.Map[String, AnyRef]): Edge = ??? - override def addEdgeSilentImpl(label: String, inNode: Node, keyValues: Any*): Unit = ??? - override def addEdgeSilentImpl(label: String, inNode: Node, keyValues: util.Map[String, AnyRef]): Unit = ??? - override def setPropertyImpl(key: String, value: Any): Unit = ??? - override def setPropertyImpl[A](key: PropertyKey[A], value: A): Unit = ??? - override def setPropertyImpl(property: Property[?]): Unit = ??? - override def removePropertyImpl(key: String): Unit = ??? - override def removeImpl(): Unit = ??? - - // Members declared in scala.Equals - def canEqual(that: Any): Boolean = ??? - - def both(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? - def both(): java.util.Iterator[overflowdb.Node] = ??? - def bothE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? - def bothE(): java.util.Iterator[overflowdb.Edge] = ??? - def id(): Long = ??? - def in(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? - def in(): java.util.Iterator[overflowdb.Node] = ??? - def inE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? - def inE(): java.util.Iterator[overflowdb.Edge] = ??? - def out(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? - def out(): java.util.Iterator[overflowdb.Node] = ??? - def outE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? - def outE(): java.util.Iterator[overflowdb.Edge] = ??? - - // Members declared in scala.Product - def productArity: Int = ??? - def productElement(n: Int): Any = ??? - - // Members declared in io.shiftleft.codepropertygraph.generated.nodes.StoredNode - def productElementLabel(n: Int): String = ??? - def valueMap: java.util.Map[String, AnyRef] = ??? + def propertiesMap: java.util.Map[String, Any] = ??? } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/MockCpg.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/MockCpg.scala new file mode 100644 index 000000000000..96ba88e2ba73 --- /dev/null +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/MockCpg.scala @@ -0,0 +1,229 @@ +package io.shiftleft.semanticcpg.testing + +import io.shiftleft.codepropertygraph.Cpg +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Languages, ModifierTypes} +import io.shiftleft.passes.CpgPass +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.codepropertygraph.generated.DiffGraphBuilder + +object MockCpg { + + def apply(): MockCpg = new MockCpg + + def apply(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = new MockCpg().withCustom(f) +} + +case class MockCpg(cpg: Cpg = Cpg.emptyCpg) { + + def withMetaData(language: String = Languages.C): MockCpg = withMetaData(language, Nil) + + def withMetaData(language: String, overlays: List[String]): MockCpg = { + withCustom { (diffGraph, _) => + diffGraph.addNode(NewMetaData().language(language).overlays(overlays)) + } + } + + def withFile(filename: String, content: Option[String] = None): MockCpg = + withCustom { (graph, _) => + val newFile = NewFile().name(filename) + content.foreach(newFile.content(_)) + graph.addNode(newFile) + } + + def withNamespace(name: String, inFile: Option[String] = None): MockCpg = + withCustom { (graph, _) => + { + val namespaceBlock = NewNamespaceBlock().name(name) + val namespace = NewNamespace().name(name) + graph.addNode(namespaceBlock) + graph.addNode(namespace) + graph.addEdge(namespaceBlock, namespace, EdgeTypes.REF) + if (inFile.isDefined) { + val fileNode = cpg.file(inFile.get).head + graph.addEdge(namespaceBlock, fileNode, EdgeTypes.SOURCE_FILE) + } + } + } + + def withTypeDecl( + name: String, + isExternal: Boolean = false, + inNamespace: Option[String] = None, + inFile: Option[String] = None, + offset: Option[Int] = None, + offsetEnd: Option[Int] = None + ): MockCpg = + withCustom { (graph, _) => + { + val typeNode = NewType().name(name) + val typeDeclNode = NewTypeDecl() + .name(name) + .fullName(name) + .isExternal(isExternal) + + offset.foreach(typeDeclNode.offset(_)) + offsetEnd.foreach(typeDeclNode.offsetEnd(_)) + + val member = NewMember().name("amember") + val modifier = NewModifier().modifierType(ModifierTypes.STATIC) + + graph.addNode(typeDeclNode) + graph.addNode(typeNode) + graph.addNode(member) + graph.addNode(modifier) + graph.addEdge(typeNode, typeDeclNode, EdgeTypes.REF) + graph.addEdge(typeDeclNode, member, EdgeTypes.AST) + graph.addEdge(member, modifier, EdgeTypes.AST) + + if (inNamespace.isDefined) { + val namespaceBlock = cpg.namespaceBlock(inNamespace.get).head + graph.addEdge(namespaceBlock, typeDeclNode, EdgeTypes.AST) + } + if (inFile.isDefined) { + val fileNode = cpg.file(inFile.get).head + graph.addEdge(typeDeclNode, fileNode, EdgeTypes.SOURCE_FILE) + } + } + } + + def withMethod( + name: String, + external: Boolean = false, + inTypeDecl: Option[String] = None, + fileName: String = "", + offset: Option[Int] = None, + offsetEnd: Option[Int] = None + ): MockCpg = + withCustom { (graph, _) => + val retParam = NewMethodReturn().typeFullName("int").order(10) + val param = NewMethodParameterIn().order(1).index(1).name("param1") + val paramType = NewType().name("paramtype") + val paramOut = NewMethodParameterOut().name("param1").order(1) + val method = + NewMethod().isExternal(external).name(name).fullName(name).signature("asignature").filename(fileName) + offset.foreach(method.offset(_)) + offsetEnd.foreach(method.offsetEnd(_)) + val block = NewBlock().typeFullName("int") + val modifier = NewModifier().modifierType("modifiertype") + + graph.addNode(method) + graph.addNode(retParam) + graph.addNode(param) + graph.addNode(paramType) + graph.addNode(paramOut) + graph.addNode(block) + graph.addNode(modifier) + graph.addEdge(method, retParam, EdgeTypes.AST) + graph.addEdge(method, param, EdgeTypes.AST) + graph.addEdge(param, paramOut, EdgeTypes.PARAMETER_LINK) + graph.addEdge(method, block, EdgeTypes.AST) + graph.addEdge(param, paramType, EdgeTypes.EVAL_TYPE) + graph.addEdge(paramOut, paramType, EdgeTypes.EVAL_TYPE) + graph.addEdge(method, modifier, EdgeTypes.AST) + + if (inTypeDecl.isDefined) { + val typeDeclNode = cpg.typeDecl(inTypeDecl.get).head + graph.addEdge(typeDeclNode, method, EdgeTypes.AST) + } + + if (fileName != "") { + val file = cpg.file + .nameExact(fileName) + .headOption + .getOrElse(throw new RuntimeException(s"file with name='$fileName' not found")) + graph.addEdge(method, file, EdgeTypes.SOURCE_FILE) + } + } + + def withTagsOnMethod( + methodName: String, + methodTags: List[(String, String)] = List(), + paramTags: List[(String, String)] = List() + ): MockCpg = + withCustom { (graph, cpg) => + implicit val diffGraph: DiffGraphBuilder = graph + methodTags.foreach { case (k, v) => + cpg.method(methodName).newTagNodePair(k, v).store()(diffGraph) + } + paramTags.foreach { case (k, v) => + cpg.method(methodName).parameter.newTagNodePair(k, v).store()(diffGraph) + } + } + + def withCallInMethod(methodName: String, callName: String, code: Option[String] = None): MockCpg = + withCustom { (graph, cpg) => + val methodNode = cpg.method(methodName).head + val blockNode = methodNode.block + val callNode = NewCall().name(callName).code(code.getOrElse(callName)) + graph.addNode(callNode) + graph.addEdge(blockNode, callNode, EdgeTypes.AST) + graph.addEdge(methodNode, callNode, EdgeTypes.CONTAINS) + } + + def withMethodCall(calledMethod: String, callingMethod: String, code: Option[String] = None): MockCpg = + withCustom { (graph, cpg) => + val callingMethodNode = cpg.method(callingMethod).head + val calledMethodNode = cpg.method(calledMethod).head + val callNode = NewCall().name(calledMethod).code(code.getOrElse(calledMethod)) + graph.addEdge(callNode, calledMethodNode, EdgeTypes.CALL) + graph.addEdge(callingMethodNode, callNode, EdgeTypes.CONTAINS) + } + + def withLocalInMethod(methodName: String, localName: String): MockCpg = + withCustom { (graph, cpg) => + val methodNode = cpg.method(methodName).head + val blockNode = methodNode.block + val typeNode = NewType().name("alocaltype") + val localNode = NewLocal().name(localName).typeFullName("alocaltype") + graph.addNode(localNode) + graph.addNode(typeNode) + graph.addEdge(blockNode, localNode, EdgeTypes.AST) + graph.addEdge(localNode, typeNode, EdgeTypes.EVAL_TYPE) + } + + def withLiteralArgument(callName: String, literalCode: String): MockCpg = { + withCustom { (graph, cpg) => + val callNode = cpg.call(callName).head + val methodNode = callNode.method + val literalNode = NewLiteral().code(literalCode) + val typeDecl = NewTypeDecl() + .name("ATypeDecl") + .fullName("ATypeDecl") + + graph.addNode(typeDecl) + graph.addNode(literalNode) + graph.addEdge(callNode, literalNode, EdgeTypes.AST) + graph.addEdge(methodNode, literalNode, EdgeTypes.CONTAINS) + } + } + + def withIdentifierArgument(callName: String, name: String, index: Int = 1): MockCpg = + withArgument(callName, NewIdentifier().name(name).argumentIndex(index)) + + def withCallArgument(callName: String, callArgName: String, code: String = "", index: Int = 1): MockCpg = + withArgument(callName, NewCall().name(callArgName).code(code).argumentIndex(index)) + + def withArgument(callName: String, newNode: NewNode): MockCpg = withCustom { (graph, cpg) => + val callNode = cpg.call(callName).head + val methodNode = callNode.method + val typeDecl = NewTypeDecl().name("abc") + graph.addEdge(callNode, newNode, EdgeTypes.AST) + graph.addEdge(callNode, newNode, EdgeTypes.ARGUMENT) + graph.addEdge(methodNode, newNode, EdgeTypes.CONTAINS) + graph.addEdge(newNode, typeDecl, EdgeTypes.REF) + graph.addNode(newNode) + } + + def withCustom(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = { + val diffGraph = new DiffGraphBuilder(cpg.graph.schema) + f(diffGraph, cpg) + class MyPass extends CpgPass(cpg) { + override def run(builder: DiffGraphBuilder): Unit = { + builder.absorb(diffGraph) + } + } + new MyPass().createAndApply() + this + } +} diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/package.scala deleted file mode 100644 index 1207a9efa669..000000000000 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/package.scala +++ /dev/null @@ -1,234 +0,0 @@ -package io.shiftleft.semanticcpg - -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Languages, ModifierTypes} -import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ -import overflowdb.BatchedUpdate -import overflowdb.BatchedUpdate.DiffGraphBuilder - -package object testing { - - object MockCpg { - - def apply(): MockCpg = new MockCpg - - def apply(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = new MockCpg().withCustom(f) - } - - case class MockCpg(cpg: Cpg = Cpg.empty) { - - def withMetaData(language: String = Languages.C): MockCpg = withMetaData(language, Nil) - - def withMetaData(language: String, overlays: List[String]): MockCpg = { - withCustom { (diffGraph, _) => - diffGraph.addNode(NewMetaData().language(language).overlays(overlays)) - } - } - - def withFile(filename: String, content: Option[String] = None): MockCpg = - withCustom { (graph, _) => - val newFile = NewFile().name(filename) - content.foreach(newFile.content(_)) - graph.addNode(newFile) - } - - def withNamespace(name: String, inFile: Option[String] = None): MockCpg = - withCustom { (graph, _) => - { - val namespaceBlock = NewNamespaceBlock().name(name) - val namespace = NewNamespace().name(name) - graph.addNode(namespaceBlock) - graph.addNode(namespace) - graph.addEdge(namespaceBlock, namespace, EdgeTypes.REF) - if (inFile.isDefined) { - val fileNode = cpg.file.name(inFile.get).head - graph.addEdge(namespaceBlock, fileNode, EdgeTypes.SOURCE_FILE) - } - } - } - - def withTypeDecl( - name: String, - isExternal: Boolean = false, - inNamespace: Option[String] = None, - inFile: Option[String] = None, - offset: Option[Int] = None, - offsetEnd: Option[Int] = None - ): MockCpg = - withCustom { (graph, _) => - { - val typeNode = NewType().name(name) - val typeDeclNode = NewTypeDecl() - .name(name) - .fullName(name) - .isExternal(isExternal) - - offset.foreach(typeDeclNode.offset(_)) - offsetEnd.foreach(typeDeclNode.offsetEnd(_)) - - val member = NewMember().name("amember") - val modifier = NewModifier().modifierType(ModifierTypes.STATIC) - - graph.addNode(typeDeclNode) - graph.addNode(typeNode) - graph.addNode(member) - graph.addNode(modifier) - graph.addEdge(typeNode, typeDeclNode, EdgeTypes.REF) - graph.addEdge(typeDeclNode, member, EdgeTypes.AST) - graph.addEdge(member, modifier, EdgeTypes.AST) - - if (inNamespace.isDefined) { - val namespaceBlock = cpg.namespaceBlock(inNamespace.get).head - graph.addEdge(namespaceBlock, typeDeclNode, EdgeTypes.AST) - } - if (inFile.isDefined) { - val fileNode = cpg.file.name(inFile.get).head - graph.addEdge(typeDeclNode, fileNode, EdgeTypes.SOURCE_FILE) - } - } - } - - def withMethod( - name: String, - external: Boolean = false, - inTypeDecl: Option[String] = None, - fileName: String = "", - offset: Option[Int] = None, - offsetEnd: Option[Int] = None - ): MockCpg = - withCustom { (graph, _) => - val retParam = NewMethodReturn().typeFullName("int").order(10) - val param = NewMethodParameterIn().order(1).index(1).name("param1") - val paramType = NewType().name("paramtype") - val paramOut = NewMethodParameterOut().name("param1").order(1) - val method = - NewMethod().isExternal(external).name(name).fullName(name).signature("asignature").filename(fileName) - offset.foreach(method.offset(_)) - offsetEnd.foreach(method.offsetEnd(_)) - val block = NewBlock().typeFullName("int") - val modifier = NewModifier().modifierType("modifiertype") - - graph.addNode(method) - graph.addNode(retParam) - graph.addNode(param) - graph.addNode(paramType) - graph.addNode(paramOut) - graph.addNode(block) - graph.addNode(modifier) - graph.addEdge(method, retParam, EdgeTypes.AST) - graph.addEdge(method, param, EdgeTypes.AST) - graph.addEdge(param, paramOut, EdgeTypes.PARAMETER_LINK) - graph.addEdge(method, block, EdgeTypes.AST) - graph.addEdge(param, paramType, EdgeTypes.EVAL_TYPE) - graph.addEdge(paramOut, paramType, EdgeTypes.EVAL_TYPE) - graph.addEdge(method, modifier, EdgeTypes.AST) - - if (inTypeDecl.isDefined) { - val typeDeclNode = cpg.typeDecl.name(inTypeDecl.get).head - graph.addEdge(typeDeclNode, method, EdgeTypes.AST) - } - - if (fileName != "") { - val file = cpg.file - .nameExact(fileName) - .headOption - .getOrElse(throw new RuntimeException(s"file with name='$fileName' not found")) - graph.addEdge(method, file, EdgeTypes.SOURCE_FILE) - } - } - - def withTagsOnMethod( - methodName: String, - methodTags: List[(String, String)] = List(), - paramTags: List[(String, String)] = List() - ): MockCpg = - withCustom { (graph, cpg) => - implicit val diffGraph: DiffGraphBuilder = graph - methodTags.foreach { case (k, v) => - cpg.method.name(methodName).newTagNodePair(k, v).store()(diffGraph) - } - paramTags.foreach { case (k, v) => - cpg.method.name(methodName).parameter.newTagNodePair(k, v).store()(diffGraph) - } - } - - def withCallInMethod(methodName: String, callName: String, code: Option[String] = None): MockCpg = - withCustom { (graph, cpg) => - val methodNode = cpg.method.name(methodName).head - val blockNode = methodNode.block - val callNode = NewCall().name(callName).code(code.getOrElse(callName)) - graph.addNode(callNode) - graph.addEdge(blockNode, callNode, EdgeTypes.AST) - graph.addEdge(methodNode, callNode, EdgeTypes.CONTAINS) - } - - def withMethodCall(calledMethod: String, callingMethod: String, code: Option[String] = None): MockCpg = - withCustom { (graph, cpg) => - val callingMethodNode = cpg.method.name(callingMethod).head - val calledMethodNode = cpg.method.name(calledMethod).head - val callNode = NewCall().name(calledMethod).code(code.getOrElse(calledMethod)) - graph.addEdge(callNode, calledMethodNode, EdgeTypes.CALL) - graph.addEdge(callingMethodNode, callNode, EdgeTypes.CONTAINS) - } - - def withLocalInMethod(methodName: String, localName: String): MockCpg = - withCustom { (graph, cpg) => - val methodNode = cpg.method.name(methodName).head - val blockNode = methodNode.block - val typeNode = NewType().name("alocaltype") - val localNode = NewLocal().name(localName).typeFullName("alocaltype") - graph.addNode(localNode) - graph.addNode(typeNode) - graph.addEdge(blockNode, localNode, EdgeTypes.AST) - graph.addEdge(localNode, typeNode, EdgeTypes.EVAL_TYPE) - } - - def withLiteralArgument(callName: String, literalCode: String): MockCpg = { - withCustom { (graph, cpg) => - val callNode = cpg.call.name(callName).head - val methodNode = callNode.method - val literalNode = NewLiteral().code(literalCode) - val typeDecl = NewTypeDecl() - .name("ATypeDecl") - .fullName("ATypeDecl") - - graph.addNode(typeDecl) - graph.addNode(literalNode) - graph.addEdge(callNode, literalNode, EdgeTypes.AST) - graph.addEdge(methodNode, literalNode, EdgeTypes.CONTAINS) - } - } - - def withIdentifierArgument(callName: String, name: String, index: Int = 1): MockCpg = - withArgument(callName, NewIdentifier().name(name).argumentIndex(index)) - - def withCallArgument(callName: String, callArgName: String, code: String = "", index: Int = 1): MockCpg = - withArgument(callName, NewCall().name(callArgName).code(code).argumentIndex(index)) - - def withArgument(callName: String, newNode: NewNode): MockCpg = withCustom { (graph, cpg) => - val callNode = cpg.call.name(callName).head - val methodNode = callNode.method - val typeDecl = NewTypeDecl().name("abc") - graph.addEdge(callNode, newNode, EdgeTypes.AST) - graph.addEdge(callNode, newNode, EdgeTypes.ARGUMENT) - graph.addEdge(methodNode, newNode, EdgeTypes.CONTAINS) - graph.addEdge(newNode, typeDecl, EdgeTypes.REF) - graph.addNode(newNode) - } - - def withCustom(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = { - val diffGraph = Cpg.newDiffGraphBuilder - f(diffGraph, cpg) - class MyPass extends CpgPass(cpg) { - override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = { - builder.absorb(diffGraph) - } - } - new MyPass().createAndApply() - this - } - } - -} diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala index e121e3830b91..78b76d838b71 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala @@ -1,7 +1,7 @@ package io.shiftleft.semanticcpg.utils import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* object Statements { def countAll(cpg: Cpg): Long = diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/NewNodeStepsTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/NewNodeStepsTests.scala index d47661adee88..da5b1b4e803a 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/NewNodeStepsTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/NewNodeStepsTests.scala @@ -1,15 +1,13 @@ package io.shiftleft.semanticcpg.language -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder} +import flatgraph.DiffGraphApplier.applyDiff +import io.shiftleft.codepropertygraph.generated.nodes.* import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import overflowdb.BatchedUpdate.{DiffGraphBuilder, applyDiff} - -import scala.jdk.CollectionConverters._ class NewNodeStepsTest extends AnyWordSpec with Matchers { - import io.shiftleft.semanticcpg.language.NewNodeNodeStepsTest._ + import io.shiftleft.semanticcpg.language.NewNodeNodeStepsTest.* "stores NewNodes" in { implicit val diffGraphBuilder: DiffGraphBuilder = Cpg.newDiffGraphBuilder @@ -17,9 +15,9 @@ class NewNodeStepsTest extends AnyWordSpec with Matchers { val cpg = Cpg.empty new NewNodeSteps(newNode.start).store() - cpg.graph.nodes.toList.size shouldBe 0 + cpg.all.size shouldBe 0 applyDiff(cpg.graph, diffGraphBuilder) - cpg.graph.nodes.toList.size shouldBe 1 + cpg.all.size shouldBe 1 } "can access the node label" in { @@ -29,17 +27,19 @@ class NewNodeStepsTest extends AnyWordSpec with Matchers { "stores containedNodes and connecting edge" when { "embedding a StoredNode and a NewNode" in { - implicit val diffGraphBuilder: DiffGraphBuilder = Cpg.newDiffGraphBuilder - val cpg = Cpg.empty - val existingContainedNode = cpg.graph.addNode(42L, "MODIFIER").asInstanceOf[StoredNode] - cpg.graph.V().asScala.toSet shouldBe Set(existingContainedNode) + val cpg = Cpg.empty + val newModifier = NewModifier() + applyDiff(cpg.graph, Cpg.newDiffGraphBuilder.addNode(newModifier)) + val existingContainedNode = newModifier.storedRef.get + cpg.graph.allNodes.toSet shouldBe Set(existingContainedNode) - val newContainedNode = newTestNode() - val newNode = newTestNode(evidence = List(existingContainedNode, newContainedNode)) + implicit val diffGraphBuilder: DiffGraphBuilder = Cpg.newDiffGraphBuilder + val newContainedNode = newTestNode() + val newNode = newTestNode(evidence = List(existingContainedNode, newContainedNode)) new NewNodeSteps(newNode.start).store() - cpg.graph.V().asScala.length shouldBe 1 + cpg.all.length shouldBe 1 applyDiff(cpg.graph, diffGraphBuilder) - cpg.graph.V().asScala.length shouldBe 3 + cpg.all.length shouldBe 3 } "embedding a NewNode recursively" in { @@ -49,9 +49,9 @@ class NewNodeStepsTest extends AnyWordSpec with Matchers { val newContainedNodeL0 = newTestNode(evidence = List(newContainedNodeL1)) val newNode = newTestNode(evidence = List(newContainedNodeL0)) new NewNodeSteps(newNode.start).store() - cpg.graph.V().asScala.size shouldBe 0 + cpg.all.size shouldBe 0 applyDiff(cpg.graph, diffGraphBuilder) - cpg.graph.V().asScala.size shouldBe 3 + cpg.all.size shouldBe 3 } } diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/OverlaysTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/OverlaysTests.scala similarity index 96% rename from semanticcpg/src/test/scala/io/shiftleft/semanticcpg/OverlaysTests.scala rename to semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/OverlaysTests.scala index f1fcf8098a71..f11d33aed67f 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/OverlaysTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/OverlaysTests.scala @@ -1,6 +1,6 @@ package io.shiftleft.semanticcpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/SarifTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/SarifTests.scala new file mode 100644 index 000000000000..ec49c1413e46 --- /dev/null +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/SarifTests.scala @@ -0,0 +1,364 @@ +package io.shiftleft.semanticcpg.language + +import flatgraph.DiffGraphApplier +import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.codepropertygraph.generated.nodes.{NewFinding, NewKeyValuePair, NewMethod} +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class SarifTests extends AnyWordSpec with Matchers { + + import SarifTests.* + + "a CPG without finding nodes" should { + val cpg = Cpg.empty + + "create a SARIF file with empty results" in { + val sarif = cpg.finding.toSarif + sarif.version shouldBe "2.1.0" + sarif.schema shouldBe "https://docs.oasis-open.org/sarif/sarif/v2.1.0/errata01/os/schemas/sarif-schema-2.1.0.json" + sarif.runs.size shouldBe 1 + val run = sarif.runs.head + run.results shouldBe Nil + val tool = run.tool.driver + tool.name shouldBe "Joern" + tool.fullName shouldBe Option("Joern - The Bug Hunter's Workbench") + tool.organization shouldBe Option("Joern.io") + } + } + + "an iterable with a single finding node with all expected properties" should { + + val cpg = Cpg.empty + + createValidFindingNode(cpg) + + "create a valid SARIF result" in { + val sarif = cpg.finding.toSarif() + val run = sarif.runs.head + val rules = run.tool.driver.rules + + rules.size shouldBe 1 + val rule = rules.head + rule.id shouldBe "f1" + rule.name shouldBe "Rule 1" + rule.shortDescription shouldBe None + rule.fullDescription.map(_.text) shouldBe Some("something bad happened") + rule.helpUri shouldBe None + + val results = run.results + results.size shouldBe 1 + + val result = results.head + + result.ruleId shouldBe "f1" + result.message.text shouldBe "Rule 1" + result.level shouldBe "error" + + val region = result.locations.head.physicalLocation.region + + region.startLine shouldBe Some(2) + region.snippet.map(_.text) shouldBe Some("public foo()") + + val artifactLocation = result.locations.head.physicalLocation.artifactLocation + artifactLocation.uri.map(_.toString) shouldBe Some("Bar.java") + + result.codeFlows.size shouldBe 1 + result.codeFlows.head.message shouldBe None + } + + "create a valid SARIF JSON" in { + cpg.finding.toSarifJson(pretty = true) shouldBe + """{ + | "version":"2.1.0", + | "$schema":"https://docs.oasis-open.org/sarif/sarif/v2.1.0/errata01/os/schemas/sarif-schema-2.1.0.json", + | "runs":[ + | { + | "tool":{ + | "driver":{ + | "organization":"Joern.io", + | "name":"Joern", + | "informationUri":"https://joern.io", + | "fullName":"Joern - The Bug Hunter's Workbench", + | "rules":[ + | { + | "id":"f1", + | "name":"Rule 1", + | "fullDescription":{ + | "text":"something bad happened" + | } + | } + | ] + | } + | }, + | "results":[ + | { + | "locations":[ + | { + | "physicalLocation":{ + | "artifactLocation":{ + | "uri":"Bar.java", + | "uriBaseId":"PROJECT_ROOT" + | }, + | "region":{ + | "startLine":2, + | "snippet":{ + | "text":"public foo()" + | } + | } + | } + | } + | ], + | "relatedLocations":[ + | { + | "physicalLocation":{ + | "artifactLocation":{ + | "uri":"Bar.java", + | "uriBaseId":"PROJECT_ROOT" + | }, + | "region":{ + | "startLine":2, + | "snippet":{ + | "text":"public foo()" + | } + | } + | } + | } + | ], + | "message":{ + | "text":"Rule 1" + | }, + | "codeFlows":[ + | { + | "threadFlows":[ + | { + | "locations":[ + | { + | "location":{ + | "physicalLocation":{ + | "artifactLocation":{ + | "uri":"Bar.java", + | "uriBaseId":"PROJECT_ROOT" + | }, + | "region":{ + | "startLine":2, + | "snippet":{ + | "text":"public foo()" + | } + | } + | } + | } + | } + | ] + | } + | ] + | } + | ], + | "ruleId":"f1", + | "level":"error" + | } + | ], + | "originalUriBaseIds":{ + | "PROJECT_ROOT":{ + | "uriBaseId":"" + | } + | } + | } + | ] + |} + | + |""".stripMargin.trim + } + + } + + "an iterable with a single finding node with missing properties" should { + + val cpg = Cpg.empty + + createInvalidFindingNode(cpg) + + "create a valid SARIF result" in { + val sarif = cpg.finding.toSarif() + val run = sarif.runs.head + val rules = run.tool.driver.rules + + rules.size shouldBe 1 + val rule = rules.head + rule.id shouldBe "f1" + rule.name shouldBe "" + rule.shortDescription shouldBe None + rule.fullDescription.map(_.text) shouldBe Some("something bad happened") + rule.helpUri shouldBe None + + val results = run.results + results.size shouldBe 1 + val result = results.head + + result.ruleId shouldBe "f1" + result.message.text shouldBe "" + result.level shouldBe "warning" + + val region = result.locations.head.physicalLocation.region + + region.startLine shouldBe Some(2) + region.snippet.map(_.text) shouldBe Some("public foo()") + + val artifactLocation = result.locations.head.physicalLocation.artifactLocation + artifactLocation.uri.map(_.toString) shouldBe None + + result.codeFlows.size shouldBe 1 + result.codeFlows.head.message shouldBe None + } + + "create a valid SARIF JSON" in { + cpg.finding.toSarifJson(pretty = true) shouldBe + """ + |{ + | "version":"2.1.0", + | "$schema":"https://docs.oasis-open.org/sarif/sarif/v2.1.0/errata01/os/schemas/sarif-schema-2.1.0.json", + | "runs":[ + | { + | "tool":{ + | "driver":{ + | "organization":"Joern.io", + | "name":"Joern", + | "informationUri":"https://joern.io", + | "fullName":"Joern - The Bug Hunter's Workbench", + | "rules":[ + | { + | "id":"f1", + | "name":"", + | "fullDescription":{ + | "text":"something bad happened" + | } + | } + | ] + | } + | }, + | "results":[ + | { + | "locations":[ + | { + | "physicalLocation":{ + | "artifactLocation":{ + | "uriBaseId":"PROJECT_ROOT" + | }, + | "region":{ + | "startLine":2, + | "snippet":{ + | "text":"public foo()" + | } + | } + | } + | } + | ], + | "relatedLocations":[ + | { + | "physicalLocation":{ + | "artifactLocation":{ + | "uriBaseId":"PROJECT_ROOT" + | }, + | "region":{ + | "startLine":2, + | "snippet":{ + | "text":"public foo()" + | } + | } + | } + | } + | ], + | "message":{ + | "text":"" + | }, + | "codeFlows":[ + | { + | "threadFlows":[ + | { + | "locations":[ + | { + | "location":{ + | "physicalLocation":{ + | "artifactLocation":{ + | "uriBaseId":"PROJECT_ROOT" + | }, + | "region":{ + | "startLine":2, + | "snippet":{ + | "text":"public foo()" + | } + | } + | } + | } + | } + | ] + | } + | ] + | } + | ], + | "ruleId":"f1", + | "level":"warning" + | } + | ], + | "originalUriBaseIds":{ + | "PROJECT_ROOT":{ + | "uriBaseId":"" + | } + | } + | } + | ] + |} + |""".stripMargin.trim + } + + } + +} + +object SarifTests { + + def createValidFindingNode(cpg: Cpg): Unit = { + val dg = Cpg.newDiffGraphBuilder + val method = NewMethod() + .name("Foo") + .lineNumber(2) + .filename("Bar.java") + .code("public foo()") + val finding = NewFinding() + .evidence(Iterator.single(method)) + .keyValuePairs( + List( + NewKeyValuePair().key("name").value("f1"), + NewKeyValuePair().key("title").value("Rule 1"), + NewKeyValuePair().key("description").value("something bad happened"), + NewKeyValuePair().key("score").value("8.0") + ) + ) + dg.addNode(method) + .addNode(finding) + + DiffGraphApplier.applyDiff(cpg.graph, dg) + } + + def createInvalidFindingNode(cpg: Cpg): Unit = { + val dg = Cpg.newDiffGraphBuilder + val method = NewMethod() + .name("Foo") + .lineNumber(2) + .code("public foo()") + .filename("not compliant uri") + val finding = NewFinding() + .evidence(Iterator.single(method)) + .keyValuePairs( + List( + NewKeyValuePair().key("name").value("f1"), + NewKeyValuePair().key("description").value("something bad happened") + ) + ) + dg.addNode(method) + .addNode(finding) + + DiffGraphApplier.applyDiff(cpg.graph, dg) + } + +} diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/StepsTest.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/StepsTest.scala index 8e9107fe4412..1d276d2341c6 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/StepsTest.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/StepsTest.scala @@ -1,18 +1,15 @@ package io.shiftleft.semanticcpg.language -import io.shiftleft.codepropertygraph.Cpg.docSearchPackages import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{NodeTypes, Properties} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg +import flatgraph.help.Table.{AvailableWidthProvider, ConstantWidth} import org.json4s.* import org.json4s.native.JsonMethods.parse import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import overflowdb.traversal.help.Table.{AvailableWidthProvider, ConstantWidth} - -import java.util.Optional -import scala.jdk.CollectionConverters.IteratorHasAsScala class StepsTest extends AnyWordSpec with Matchers { @@ -52,7 +49,7 @@ class StepsTest extends AnyWordSpec with Matchers { val method: Method = cpg.method.head val results: List[Method] = cpg.method.id(method.id).toList results.size shouldBe 1 - results.head.underlying.id + results.head.id } "providing multiple" in { @@ -123,6 +120,9 @@ class StepsTest extends AnyWordSpec with Matchers { val parsed = parse(json).children.head // exactly one result for the above query (parsed \ "_label") shouldBe JString("METHOD") (parsed \ "name") shouldBe JString("foo") + + // id should be defined, but we don't care what number it is + (parsed \ "_id") shouldBe a[JInt] } "operating on NewNode" in { @@ -158,7 +158,7 @@ class StepsTest extends AnyWordSpec with Matchers { val nodeId = mainMethods.head.id val printed = mainMethods.p.head printed.should(startWith(s"""(METHOD,$nodeId):""")) - printed.should(include("IS_EXTERNAL: false")) + printed.should(include("SIGNATURE: asignature")) printed.should(include("FULL_NAME: woo")) } @@ -197,19 +197,19 @@ class StepsTest extends AnyWordSpec with Matchers { "show domain overview" in { val domainStartersHelp = Cpg.empty.help domainStartersHelp should include(".comment") - domainStartersHelp should include("All comments in source-based CPGs") + domainStartersHelp should include("A source code comment") domainStartersHelp should include(".arithmetic") domainStartersHelp should include("All arithmetic operations") } "provide node-specific overview" in { val methodStepsHelp = Cpg.empty.method.help - methodStepsHelp should include("Available steps for Method") + methodStepsHelp should include("Available steps for `Method`") methodStepsHelp should include(".namespace") methodStepsHelp should include(".depth") // from AstNode val methodStepsHelpVerbose = Cpg.empty.method.helpVerbose - methodStepsHelpVerbose should include("traversal name") + methodStepsHelpVerbose should include("implemented in") methodStepsHelpVerbose should include("structure.MethodTraversal") val assignmentStepsHelp = Cpg.empty.assignment.help @@ -283,7 +283,6 @@ class StepsTest extends AnyWordSpec with Matchers { def methodParameterOut = cpg.graph .nodes(NodeTypes.METHOD_PARAMETER_OUT) - .asScala .cast[MethodParameterOut] .name("param1") methodParameterOut.typ.name.head shouldBe "paramtype" @@ -305,7 +304,7 @@ class StepsTest extends AnyWordSpec with Matchers { file.typeDecl.name.head shouldBe "AClass" file.head.typeDecl.name.head shouldBe "AClass" - def block = cpg.graph.nodes(NodeTypes.BLOCK).asScala.cast[Block].typeFullName("int") + def block = cpg.graph.nodes(NodeTypes.BLOCK).cast[Block].typeFullName("int") block.local.name.size shouldBe 1 block.flatMap(_.local.name).size shouldBe 1 @@ -349,13 +348,6 @@ class StepsTest extends AnyWordSpec with Matchers { method.head.modifier.modifierType.toSetMutable shouldBe Set("modifiertype") } - "id starter step" in { - // only verifying what compiles and what doesn't... - // if it compiles, :shipit: - assertCompiles("cpg.id(1).out") - assertDoesNotCompile("cpg.id(1).outV") // `.outV` is only available on Traversal[Edge] - } - "property accessors" in { val cpg = MockCpg().withCustom { (diffGraph, _) => diffGraph @@ -370,21 +362,35 @@ class StepsTest extends AnyWordSpec with Matchers { val (Seq(emptyCall), Seq(callWithProperties)) = cpg.call.l.partition(_.argumentName.isEmpty) - emptyCall.propertyOption(Properties.TypeFullName) shouldBe Optional.of("") - emptyCall.propertyOption(Properties.TypeFullName.name) shouldBe Optional.of("") - emptyCall.propertyOption(Properties.ArgumentName) shouldBe Optional.empty - emptyCall.propertyOption(Properties.ArgumentName.name) shouldBe Optional.empty + // Cardinality.One + emptyCall.property(Properties.TypeFullName) shouldBe "" + emptyCall.propertyOption(Properties.TypeFullName) shouldBe Some("") + emptyCall.propertyOption(Properties.TypeFullName.name) shouldBe Some("") + // Cardinality.ZeroOrOne + emptyCall.property(Properties.ArgumentName) shouldBe None + emptyCall.propertyOption(Properties.ArgumentName) shouldBe None + emptyCall.propertyOption(Properties.ArgumentName.name) shouldBe None + // Cardinality.List // these ones are rather a historic accident it'd be better and more consistent to return `None` here - // we'll defer that change until after the flatgraph port though and just document it for now - emptyCall.propertyOption(Properties.DynamicTypeHintFullName) shouldBe Optional.of(Seq.empty) - emptyCall.propertyOption(Properties.DynamicTypeHintFullName.name) shouldBe Optional.of(Seq.empty) - - callWithProperties.propertyOption(Properties.TypeFullName) shouldBe Optional.of("aa") - callWithProperties.propertyOption(Properties.TypeFullName.name) shouldBe Optional.of("aa") - callWithProperties.propertyOption(Properties.ArgumentName) shouldBe Optional.of("bb") - callWithProperties.propertyOption(Properties.ArgumentName.name) shouldBe Optional.of("bb") - callWithProperties.propertyOption(Properties.DynamicTypeHintFullName) shouldBe Optional.of(Seq("cc", "dd")) - callWithProperties.propertyOption(Properties.DynamicTypeHintFullName.name) shouldBe Optional.of(Seq("cc", "dd")) + emptyCall.property(Properties.DynamicTypeHintFullName) shouldBe Seq.empty + emptyCall.propertyOption(Properties.DynamicTypeHintFullName) shouldBe Some(Seq.empty) + emptyCall.propertyOption(Properties.DynamicTypeHintFullName.name) shouldBe Some(Seq.empty) + + // Cardinality.One + callWithProperties.property(Properties.TypeFullName) shouldBe "aa" + callWithProperties.propertyOption(Properties.TypeFullName) shouldBe Some("aa") + callWithProperties.propertyOption(Properties.TypeFullName.name) shouldBe Some("aa") + + // Cardinality.ZeroOrOne + callWithProperties.property(Properties.ArgumentName) shouldBe Some("bb") + callWithProperties.propertyOption(Properties.ArgumentName) shouldBe Some("bb") + callWithProperties.propertyOption(Properties.ArgumentName.name) shouldBe Some("bb") + + // Cardinality.List + callWithProperties.property(Properties.DynamicTypeHintFullName) shouldBe Seq("cc", "dd") + callWithProperties.propertyOption(Properties.DynamicTypeHintFullName) shouldBe Some(Seq("cc", "dd")) + callWithProperties.propertyOption(Properties.DynamicTypeHintFullName.name) shouldBe Some(Seq("cc", "dd")) } } diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/accesspath/AccessPathTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/accesspath/AccessPathTests.scala similarity index 98% rename from semanticcpg/src/test/scala/io/shiftleft/semanticcpg/accesspath/AccessPathTests.scala rename to semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/accesspath/AccessPathTests.scala index 3cb15128a159..8a191132d0a8 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/accesspath/AccessPathTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/accesspath/AccessPathTests.scala @@ -1,7 +1,7 @@ package io.shiftleft.semanticcpg.accesspath -import io.shiftleft.semanticcpg.accesspath.MatchResult._ -import org.scalatest.matchers.should.Matchers._ +import io.shiftleft.semanticcpg.accesspath.MatchResult.* +import org.scalatest.matchers.should.Matchers.* import org.scalatest.wordspec.AnyWordSpec class AccessPathTests extends AnyWordSpec { diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/bindingextension/BindingTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/bindingextension/BindingTests.scala index 76db43185dbb..195e5472f2d7 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/bindingextension/BindingTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/bindingextension/BindingTests.scala @@ -1,8 +1,8 @@ package io.shiftleft.semanticcpg.language.bindingextension import io.shiftleft.codepropertygraph.generated.EdgeTypes -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/operatorextension/OperatorExtensionTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/operatorextension/OperatorExtensionTests.scala index fe3c0938d0a6..650914868147 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/operatorextension/OperatorExtensionTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/operatorextension/OperatorExtensionTests.scala @@ -3,8 +3,7 @@ package io.shiftleft.semanticcpg.language.operatorextension import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.Identifier -import io.shiftleft.semanticcpg.language._ -import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.ArrayAccess +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversalTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversalTests.scala index 361bc45d222e..aedde1cb4dad 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversalTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversalTests.scala @@ -1,7 +1,7 @@ package io.shiftleft.semanticcpg.language.types.expressions.generalizations import io.shiftleft.codepropertygraph.generated.EdgeTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversalTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversalTests.scala index 1ec5f62bd598..0a9d459c5b48 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversalTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversalTests.scala @@ -2,7 +2,7 @@ package io.shiftleft.semanticcpg.language.types.expressions.generalizations import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.Call -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/FileTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/FileTests.scala index 3c9a9728c270..c750f85b538a 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/FileTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/FileTests.scala @@ -1,7 +1,7 @@ package io.shiftleft.semanticcpg.language.types.structure import io.shiftleft.codepropertygraph.generated.nodes.{File, Namespace, TypeDecl} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg import org.scalatest.LoneElement import org.scalatest.matchers.should.Matchers diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTests.scala index a12a3a7edb2e..3972b43d70b5 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTests.scala @@ -1,7 +1,7 @@ package io.shiftleft.semanticcpg.language.types.structure import io.shiftleft.codepropertygraph.generated.ModifierTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTests.scala index d98aa68f5811..5694017e9942 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTests.scala @@ -1,7 +1,7 @@ package io.shiftleft.semanticcpg.language.types.structure import io.shiftleft.codepropertygraph.generated.nodes.{Method, MethodParameterIn} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTests.scala index 908a6a3d8a16..c80baace6241 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTests.scala @@ -2,7 +2,7 @@ package io.shiftleft.semanticcpg.language.types.structure import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.{CfgNode, Expression, Literal, Method, NamespaceBlock, TypeDecl} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTests.scala index 184c015a0ceb..2e29239b6f0f 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTests.scala @@ -1,7 +1,7 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.testing.MockCpg import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTests.scala index 19442132e951..dedad804af74 100644 --- a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTests.scala +++ b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTests.scala @@ -70,8 +70,6 @@ class TypeTests extends AnyWordSpec with Matchers { .name(".*Base") .toList - cpg.typeDecl.name(".*Derived").baseTypeDecl.foreach(println) - queryResult.size shouldBe 1 } diff --git a/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/utils/CountStatementsTests.scala b/semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/utils/CountStatementsTests.scala similarity index 100% rename from semanticcpg/src/test/scala/io/shiftleft/semanticcpg/utils/CountStatementsTests.scala rename to semanticcpg/src/test/scala/io/shiftleft/semanticcpg/language/utils/CountStatementsTests.scala diff --git a/test-main.sc b/test-main.sc deleted file mode 100644 index 9fe313e18910..000000000000 --- a/test-main.sc +++ /dev/null @@ -1,3 +0,0 @@ -@main def main() = { - println("Hello, world!") -} diff --git a/test-simple.sc b/test-simple.sc deleted file mode 100644 index c4267130d162..000000000000 --- a/test-simple.sc +++ /dev/null @@ -1 +0,0 @@ -println("Hello!") diff --git a/tests/code/javasrc/SliceTest.java b/tests/code/javasrc/SliceTest.java new file mode 100644 index 000000000000..0d2b2dd099fa --- /dev/null +++ b/tests/code/javasrc/SliceTest.java @@ -0,0 +1,16 @@ + + +public class SliceTest { + + public void foo(boolean b) { + String s = new Foo("MALICIOUS"); + if (b) { + s.setFoo("SAFE"); + } + bar(b); + } + + public void bar(String x) { + System.out.println(s); + } +} \ No newline at end of file diff --git a/tests/code/sarif-test/main.c b/tests/code/sarif-test/main.c new file mode 100644 index 000000000000..78052dddf47b --- /dev/null +++ b/tests/code/sarif-test/main.c @@ -0,0 +1,10 @@ +int index_into_dst_array (char *dst, char *src, int offset) { + for(i = 0; i < strlen(src); i++) { + dst[i + + j*8 + offset] = src[i]; + } +} + +int vulnerable(size_t len, char *src) { + char *dst = malloc(len + 8); + memcpy(dst, src, len + 7); +} diff --git a/tests/finding-to-sarif-test.sh b/tests/finding-to-sarif-test.sh new file mode 100755 index 000000000000..5029be4dd5bb --- /dev/null +++ b/tests/finding-to-sarif-test.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_ABS_PATH=$(readlink -f "$0") +JOERN_TESTS_DIR=$(dirname "$SCRIPT_ABS_PATH") +JOERN="$JOERN_TESTS_DIR"/.. + +mkdir -p /tmp/sarif +./joern-scan "$JOERN_TESTS_DIR/code/sarif-test" --store +./joern --script "$JOERN_TESTS_DIR/test-sarif.sc" --param cpgFile="$JOERN/workspace/sarif-test/cpg.bin" --param outFile="/tmp/sarif/test.sarif" +exit_code=$(curl -s -X POST \ + -F "postedFiles=@/tmp/sarif/test.sarif;type=application/octet-stream" \ + https://sarifweb.azurewebsites.net/Validation/ValidateFiles | jq -r '.exitCode') + +echo "SARIF Validation Exit Code: $exit_code" + +exit $exit_code diff --git a/test-additionalfuncs.sc b/tests/test-additionalfuncs.sc similarity index 100% rename from test-additionalfuncs.sc rename to tests/test-additionalfuncs.sc diff --git a/test-cpg-callotherscript.sc b/tests/test-cpg-callotherscript.sc similarity index 100% rename from test-cpg-callotherscript.sc rename to tests/test-cpg-callotherscript.sc diff --git a/test-cpg.sc b/tests/test-cpg.sc similarity index 100% rename from test-cpg.sc rename to tests/test-cpg.sc diff --git a/tests/test-dataflow-slice.sc b/tests/test-dataflow-slice.sc new file mode 100644 index 000000000000..dfeeb9b23cb9 --- /dev/null +++ b/tests/test-dataflow-slice.sc @@ -0,0 +1,17 @@ +import upickle.default.* +import io.shiftleft.utils.IOUtils +import java.nio.file.Path +import io.joern.dataflowengineoss.slicing.{DataFlowSlice, SliceEdge} + +@main def exec(sliceFile: String) = { + val jsonContent = IOUtils.readLinesInFile(Path.of(sliceFile)).mkString + val dataFlowSlice = read[DataFlowSlice](jsonContent) + val nodeMap = dataFlowSlice.nodes.map(n => n.id -> n).toMap + val edges = dataFlowSlice.edges.toList + .map { case SliceEdge(src, dst, _) => + (nodeMap(src).lineNumber, nodeMap(dst).lineNumber) -> List(nodeMap(src).code, nodeMap(dst).code).distinct + } + .sortBy(_._1) + .flatMap(_._2) + println(edges) +} diff --git a/test-dependencies.sc b/tests/test-dependencies.sc similarity index 100% rename from test-dependencies.sc rename to tests/test-dependencies.sc diff --git a/test-main-withargs.sc b/tests/test-main-withargs.sc similarity index 100% rename from test-main-withargs.sc rename to tests/test-main-withargs.sc diff --git a/tests/test-main.sc b/tests/test-main.sc new file mode 100644 index 000000000000..b60b396e723d --- /dev/null +++ b/tests/test-main.sc @@ -0,0 +1,5 @@ +@main def main() = { + println("Hello, world!") + println(help) // should work, comes from joern commands + // val i: Int = "foo" // to test line number reporting +} diff --git a/tests/test-sarif.sc b/tests/test-sarif.sc new file mode 100644 index 000000000000..83f23572dfae --- /dev/null +++ b/tests/test-sarif.sc @@ -0,0 +1,8 @@ +// to test, run e.g. +// ./joern --script test-sarif.sc --param cpgFile=workspace/src/cpg.bin --param outFile=test.sarif + +@main def exec(cpgFile: String, outFile: String) = { + importCpg(cpgFile) + assert(cpg.finding.nonEmpty, "no findings in this cpg - please check the setup") + cpg.finding.toSarifJson() |> outFile +} diff --git a/tests/test-simple.sc b/tests/test-simple.sc new file mode 100644 index 000000000000..3984396eeab2 --- /dev/null +++ b/tests/test-simple.sc @@ -0,0 +1,3 @@ +println("Hello, world2!") +println(help) // should work, comes from joern commands +// val i: Int = "foo" // to test line number reporting