Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 277 additions & 1 deletion cmd/cg/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ package main
import (
"bufio"
"bytes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
Expand All @@ -26,18 +29,20 @@ import (
)

func main() {
// Note: sfdp is checked conditionally when -graph flag is used
tools := []string{"go", "digraph", "git"}
if !utils.ValidateTools(tools) {
os.Exit(1)
}

// Define flags
var fix = flag.Bool("fix", false, "run fix commands after analysis")
var algo = flag.String("algo", "vta", "call graph algorithm: vta (default), cha, rta, static")
var algo = flag.String("algo", "rta", "call graph algorithm: rta (default), cha, vta, static")
var progress = flag.Bool("progress", false, "show progress of completed and pending jobs")
var library = flag.String("library", "", "override library path to scan (e.g., golang.org/x/net/html)")
var symbols = flag.String("symbols", "", "override symbol(s) to scan for (comma-separated, e.g., Parse,Render)")
var fixversion = flag.String("fixversion", "", "fixed version for manual scans (e.g., v1.9.4)")
var graph = flag.String("graph", "", "generate call graph SVG visualization (default: ./site/callgraph.svg if flag used without value)")

// Custom usage function
flag.Usage = func() {
Expand All @@ -49,6 +54,12 @@ func main() {
fmt.Fprintf(os.Stderr, " When -library and -symbols are provided, they take precedence over CVE-based symbol lookup.\n")
fmt.Fprintf(os.Stderr, " Optionally use -fixversion to specify the fixed version for version comparison.\n")
fmt.Fprintf(os.Stderr, " This allows scanning for specific library/symbol combinations directly.\n")
fmt.Fprintf(os.Stderr, "\nCall Graph Visualization:\n")
fmt.Fprintf(os.Stderr, " Use -graph to generate an SVG visualization of the call graph.\n")
fmt.Fprintf(os.Stderr, " -graph : Saves to ./site/callgraph-<random>.svg (default)\n")
fmt.Fprintf(os.Stderr, " -graph=./path : Saves to ./path/callgraph-<random>.svg (directory)\n")
fmt.Fprintf(os.Stderr, " -graph=file.svg : Saves to file.svg (specific file)\n")
fmt.Fprintf(os.Stderr, " Requires: graphviz (sfdp tool) to be installed\n")
}

// Parse flags
Expand Down Expand Up @@ -349,6 +360,89 @@ func main() {
result.Errors = fixResult.Errors
}

// Generate call graph visualizations if requested (one per affected symbol)
if *graph != "" || isFlagPassed("graph") {
// Validate that sfdp is available
if !utils.ValidateTools([]string{"sfdp"}) {
errMsg := "sfdp tool not found. Please install graphviz (provides sfdp) to generate call graph visualizations"
result.Errors = append(result.Errors, errMsg)
if *progress {
fmt.Fprintf(os.Stderr, "✗ %s\n", errMsg)
}
} else if result.IsVulnerable == "true" && len(result.UsedImports) > 0 {
// Only generate graphs if vulnerable
if *progress {
fmt.Fprintf(os.Stderr, "Generating call graph visualizations for affected symbols...\n")
}

outputDir := *graph
if outputDir == "" {
// Default to ./site/ directory
outputDir = "./site"
}

// Ensure output directory exists
if info, err := os.Stat(outputDir); err != nil {
// Directory doesn't exist, create it
if err := os.MkdirAll(outputDir, 0755); err != nil {
errMsg := fmt.Sprintf("Failed to create directory %s: %v", outputDir, err)
result.Errors = append(result.Errors, errMsg)
if *progress {
fmt.Fprintf(os.Stderr, "✗ %s\n", errMsg)
}
}
} else if !info.IsDir() {
// Path exists but is not a directory
errMsg := fmt.Sprintf("Output path %s exists but is not a directory", outputDir)
result.Errors = append(result.Errors, errMsg)
if *progress {
fmt.Fprintf(os.Stderr, "✗ %s\n", errMsg)
}
}

// Generate a graph for each vulnerable symbol
result.GraphPaths = []string{}
for pkg, details := range result.UsedImports {
for _, symbol := range details.Symbols {
// Create filename: library-symbol.svg
// Sanitize names for filesystem (CVE is in the directory structure when using -graph with server)
sanitizedLib := strings.ReplaceAll(pkg, "/", "-")
sanitizedSymbol := strings.ReplaceAll(symbol, ".", "-")
sanitizedSymbol = strings.ReplaceAll(sanitizedSymbol, "*", "ptr")
sanitizedSymbol = strings.ReplaceAll(sanitizedSymbol, "(", "")
sanitizedSymbol = strings.ReplaceAll(sanitizedSymbol, ")", "")

filename := fmt.Sprintf("%s-%s.svg", sanitizedLib, sanitizedSymbol)
outputPath := filepath.Join(outputDir, filename)

if *progress {
fmt.Fprintf(os.Stderr, " Generating graph for %s.%s...\n", pkg, symbol)
}

svgPath, err := generateCallGraphSVGForSymbol(result, directory, pkg, symbol, outputPath, *progress)
if err != nil {
errMsg := fmt.Sprintf("Failed to generate call graph for %s.%s: %v", pkg, symbol, err)
result.Errors = append(result.Errors, errMsg)
if *progress {
fmt.Fprintf(os.Stderr, " ✗ Failed: %v\n", err)
}
} else {
if *progress {
fmt.Fprintf(os.Stderr, " ✓ Saved to: %s\n", svgPath)
}
result.GraphPaths = append(result.GraphPaths, svgPath)
}
}
}

if *progress && len(result.GraphPaths) > 0 {
fmt.Fprintf(os.Stderr, "✓ Generated %d call graph visualization(s)\n", len(result.GraphPaths))
}
} else if *progress {
fmt.Fprintf(os.Stderr, "Skipping graph generation: No vulnerabilities found\n")
}
}

// Generate summary and output JSON
cg.GenerateSummaryWithGemini(result)
jsonOutput, err := json.MarshalIndent(result, "", " ")
Expand All @@ -360,6 +454,188 @@ func main() {
}
}

// isFlagPassed checks if a flag was explicitly set on the command line
func isFlagPassed(name string) bool {
found := false
flag.Visit(func(f *flag.Flag) {
if f.Name == name {
found = true
}
})
return found
}

// generateRandomFilename creates a random 8-character hex string
func generateRandomFilename() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
// Fallback to timestamp if random fails
return fmt.Sprintf("%d", time.Now().Unix())
}
return hex.EncodeToString(bytes)
}

// generateCallGraphSVGForSymbol generates an SVG visualization of the call graph for a specific symbol
func generateCallGraphSVGForSymbol(result *cg.Result, directory, pkg, symbol, outputPath string, showProgress bool) (string, error) {
// Get the first main file set to generate the call graph
var files []string
var modDir string

for dir, sets := range result.Files {
if len(sets) > 0 && len(sets[0]) > 0 {
files = sets[0]
modDir = dir
break
}
}

if len(files) == 0 {
return "", fmt.Errorf("no main files found for call graph generation")
}

fullModDir := filepath.Join(directory, modDir)

// Generate call graph using the scanner's function
tempResult := &cg.Result{
Directory: directory,
Errors: []string{},
}

if showProgress {
fmt.Fprintf(os.Stderr, " Building call graph from %s...\n", fullModDir)
}

callGraphOutput, err := tempResult.GenerateCallGraphForVisualization(fullModDir, files)
if err != nil {
return "", fmt.Errorf("failed to generate call graph: %v", err)
}

// Filter the call graph to only include paths to the target symbol
if showProgress {
fmt.Fprintf(os.Stderr, " Filtering graph for symbol %s...\n", symbol)
}

filteredGraph, err := filterCallGraphForSymbol(callGraphOutput, pkg, symbol)
if err != nil {
return "", fmt.Errorf("failed to filter call graph: %v", err)
}

if filteredGraph == "" {
return "", fmt.Errorf("no paths found to symbol %s.%s", pkg, symbol)
}

if showProgress {
fmt.Fprintf(os.Stderr, " Converting to DOT format...\n")
}

// Convert call graph to DOT format using digraph
dotCmd := exec.Command("digraph", "to", "dot")
dotCmd.Stdin = strings.NewReader(filteredGraph)
dotOutput, err := dotCmd.Output()
if err != nil {
return "", fmt.Errorf("failed to convert to DOT format: %v", err)
}

if showProgress {
fmt.Fprintf(os.Stderr, " Rendering SVG with sfdp layout...\n")
}

// Convert DOT to SVG using sfdp
sfdpCmd := exec.Command("sfdp", "-Tsvg", "-Goverlap=scale")
sfdpCmd.Stdin = bytes.NewReader(dotOutput)
svgOutput, err := sfdpCmd.Output()
if err != nil {
return "", fmt.Errorf("failed to render SVG: %v", err)
}

// Write SVG to file
if err := os.WriteFile(outputPath, svgOutput, 0644); err != nil {
return "", fmt.Errorf("failed to write SVG file: %v", err)
}

return outputPath, nil
}

// filterCallGraphForSymbol filters the call graph to only include nodes relevant to reaching the target symbol
func filterCallGraphForSymbol(callGraphOutput, pkg, symbol string) (string, error) {
// Build a set of all nodes that can reach the target symbol
// Use digraph to find all paths from main to the symbol
scanner := bufio.NewScanner(strings.NewReader(callGraphOutput))
var entryPoints []string
seen := make(map[string]bool)

// Find all entry points (main functions)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
if len(fields) >= 1 {
caller := fields[0]
if strings.HasSuffix(caller, ".main") && !seen[caller] {
entryPoints = append(entryPoints, caller)
seen[caller] = true
}
}
}

if len(entryPoints) == 0 {
return "", fmt.Errorf("no entry points found")
}

// Find all paths from entry points to the target symbol
reachableNodes := make(map[string]bool)

// Try different symbol patterns
symbolPatterns := []string{
fmt.Sprintf("%s.%s", pkg, symbol),
fmt.Sprintf("(%s.%s)", pkg, symbol),
fmt.Sprintf("(*%s.%s)", pkg, symbol),
}

for _, entryPoint := range entryPoints {
for _, symPattern := range symbolPatterns {
// Use digraph allpaths to find all paths from entry point to symbol
cmd := exec.Command("digraph", "allpaths", entryPoint, symPattern)
cmd.Stdin = strings.NewReader(callGraphOutput)
output, err := cmd.Output()

if err == nil && len(bytes.TrimSpace(output)) > 0 {
// Parse the paths and add all nodes to reachableNodes
pathScanner := bufio.NewScanner(bytes.NewReader(output))
for pathScanner.Scan() {
line := pathScanner.Text()
fields := strings.Fields(line)
if len(fields) >= 2 {
reachableNodes[fields[0]] = true
reachableNodes[fields[1]] = true
}
}
}
}
}

if len(reachableNodes) == 0 {
return "", fmt.Errorf("no paths found to symbol")
}

// Build filtered graph with only reachable nodes
var filteredLines []string
scanner = bufio.NewScanner(strings.NewReader(callGraphOutput))
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
if len(fields) >= 2 {
caller := fields[0]
callee := fields[1]
// Include edge if both nodes are reachable
if reachableNodes[caller] && reachableNodes[callee] {
filteredLines = append(filteredLines, line)
}
}
}

return strings.Join(filteredLines, "\n"), nil
}

// initResultWithProgress initializes the result with progress tracking for each phase
func initResultWithProgress(cve, dir string, fix bool, library, symbols, fixversion string) *cg.Result {
r := &cg.Result{
Expand Down
2 changes: 1 addition & 1 deletion cmd/gvs/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func main() {
log.Fatalf("Failed to create directory: %v", err)
}

http.Handle("/cg/", gvs.LogFileAccess(http.StripPrefix("/cg/", http.FileServer(http.Dir("/tmp/gvs-cache/img")))))
http.Handle("/graph/", gvs.LogFileAccess(http.StripPrefix("/graph/", http.FileServer(http.Dir("/tmp/gvs-cache/graph")))))
http.Handle("/", http.FileServer(http.Dir("./site")))
http.HandleFunc("/scan", api.CORSMiddleware(api.ScanHandler))
http.HandleFunc("/healthz", api.CORSMiddleware(api.HealthHandler))
Expand Down
Loading
Loading