diff --git a/cmd/cg/main.go b/cmd/cg/main.go index 86b2f9a..124a05d 100644 --- a/cmd/cg/main.go +++ b/cmd/cg/main.go @@ -3,12 +3,15 @@ package main import ( "bufio" "bytes" + "crypto/rand" + "encoding/hex" "encoding/json" "flag" "fmt" "io" "net/http" "os" + "os/exec" "path/filepath" "runtime" "slices" @@ -26,6 +29,7 @@ import ( ) func main() { + // Note: sfdp is checked conditionally when -graph flag is used tools := []string{"go", "digraph", "git"} if !utils.ValidateTools(tools) { os.Exit(1) @@ -33,11 +37,12 @@ func main() { // Define flags var fix = flag.Bool("fix", false, "run fix commands after analysis") - var algo = flag.String("algo", "vta", "call graph algorithm: vta (default), cha, rta, static") + var algo = flag.String("algo", "rta", "call graph algorithm: rta (default), cha, vta, static") var progress = flag.Bool("progress", false, "show progress of completed and pending jobs") var library = flag.String("library", "", "override library path to scan (e.g., golang.org/x/net/html)") var symbols = flag.String("symbols", "", "override symbol(s) to scan for (comma-separated, e.g., Parse,Render)") var fixversion = flag.String("fixversion", "", "fixed version for manual scans (e.g., v1.9.4)") + var graph = flag.String("graph", "", "generate call graph SVG visualization (default: ./site/callgraph.svg if flag used without value)") // Custom usage function flag.Usage = func() { @@ -49,6 +54,12 @@ func main() { fmt.Fprintf(os.Stderr, " When -library and -symbols are provided, they take precedence over CVE-based symbol lookup.\n") fmt.Fprintf(os.Stderr, " Optionally use -fixversion to specify the fixed version for version comparison.\n") fmt.Fprintf(os.Stderr, " This allows scanning for specific library/symbol combinations directly.\n") + fmt.Fprintf(os.Stderr, "\nCall Graph Visualization:\n") + fmt.Fprintf(os.Stderr, " Use -graph to generate an SVG visualization of the call graph.\n") + fmt.Fprintf(os.Stderr, " -graph : Saves to ./site/callgraph-.svg (default)\n") + fmt.Fprintf(os.Stderr, " -graph=./path : Saves to ./path/callgraph-.svg (directory)\n") + fmt.Fprintf(os.Stderr, " -graph=file.svg : Saves to file.svg (specific file)\n") + fmt.Fprintf(os.Stderr, " Requires: graphviz (sfdp tool) to be installed\n") } // Parse flags @@ -349,6 +360,89 @@ func main() { result.Errors = fixResult.Errors } + // Generate call graph visualizations if requested (one per affected symbol) + if *graph != "" || isFlagPassed("graph") { + // Validate that sfdp is available + if !utils.ValidateTools([]string{"sfdp"}) { + errMsg := "sfdp tool not found. Please install graphviz (provides sfdp) to generate call graph visualizations" + result.Errors = append(result.Errors, errMsg) + if *progress { + fmt.Fprintf(os.Stderr, "✗ %s\n", errMsg) + } + } else if result.IsVulnerable == "true" && len(result.UsedImports) > 0 { + // Only generate graphs if vulnerable + if *progress { + fmt.Fprintf(os.Stderr, "Generating call graph visualizations for affected symbols...\n") + } + + outputDir := *graph + if outputDir == "" { + // Default to ./site/ directory + outputDir = "./site" + } + + // Ensure output directory exists + if info, err := os.Stat(outputDir); err != nil { + // Directory doesn't exist, create it + if err := os.MkdirAll(outputDir, 0755); err != nil { + errMsg := fmt.Sprintf("Failed to create directory %s: %v", outputDir, err) + result.Errors = append(result.Errors, errMsg) + if *progress { + fmt.Fprintf(os.Stderr, "✗ %s\n", errMsg) + } + } + } else if !info.IsDir() { + // Path exists but is not a directory + errMsg := fmt.Sprintf("Output path %s exists but is not a directory", outputDir) + result.Errors = append(result.Errors, errMsg) + if *progress { + fmt.Fprintf(os.Stderr, "✗ %s\n", errMsg) + } + } + + // Generate a graph for each vulnerable symbol + result.GraphPaths = []string{} + for pkg, details := range result.UsedImports { + for _, symbol := range details.Symbols { + // Create filename: library-symbol.svg + // Sanitize names for filesystem (CVE is in the directory structure when using -graph with server) + sanitizedLib := strings.ReplaceAll(pkg, "/", "-") + sanitizedSymbol := strings.ReplaceAll(symbol, ".", "-") + sanitizedSymbol = strings.ReplaceAll(sanitizedSymbol, "*", "ptr") + sanitizedSymbol = strings.ReplaceAll(sanitizedSymbol, "(", "") + sanitizedSymbol = strings.ReplaceAll(sanitizedSymbol, ")", "") + + filename := fmt.Sprintf("%s-%s.svg", sanitizedLib, sanitizedSymbol) + outputPath := filepath.Join(outputDir, filename) + + if *progress { + fmt.Fprintf(os.Stderr, " Generating graph for %s.%s...\n", pkg, symbol) + } + + svgPath, err := generateCallGraphSVGForSymbol(result, directory, pkg, symbol, outputPath, *progress) + if err != nil { + errMsg := fmt.Sprintf("Failed to generate call graph for %s.%s: %v", pkg, symbol, err) + result.Errors = append(result.Errors, errMsg) + if *progress { + fmt.Fprintf(os.Stderr, " ✗ Failed: %v\n", err) + } + } else { + if *progress { + fmt.Fprintf(os.Stderr, " ✓ Saved to: %s\n", svgPath) + } + result.GraphPaths = append(result.GraphPaths, svgPath) + } + } + } + + if *progress && len(result.GraphPaths) > 0 { + fmt.Fprintf(os.Stderr, "✓ Generated %d call graph visualization(s)\n", len(result.GraphPaths)) + } + } else if *progress { + fmt.Fprintf(os.Stderr, "Skipping graph generation: No vulnerabilities found\n") + } + } + // Generate summary and output JSON cg.GenerateSummaryWithGemini(result) jsonOutput, err := json.MarshalIndent(result, "", " ") @@ -360,6 +454,188 @@ func main() { } } +// isFlagPassed checks if a flag was explicitly set on the command line +func isFlagPassed(name string) bool { + found := false + flag.Visit(func(f *flag.Flag) { + if f.Name == name { + found = true + } + }) + return found +} + +// generateRandomFilename creates a random 8-character hex string +func generateRandomFilename() string { + bytes := make([]byte, 4) + if _, err := rand.Read(bytes); err != nil { + // Fallback to timestamp if random fails + return fmt.Sprintf("%d", time.Now().Unix()) + } + return hex.EncodeToString(bytes) +} + +// generateCallGraphSVGForSymbol generates an SVG visualization of the call graph for a specific symbol +func generateCallGraphSVGForSymbol(result *cg.Result, directory, pkg, symbol, outputPath string, showProgress bool) (string, error) { + // Get the first main file set to generate the call graph + var files []string + var modDir string + + for dir, sets := range result.Files { + if len(sets) > 0 && len(sets[0]) > 0 { + files = sets[0] + modDir = dir + break + } + } + + if len(files) == 0 { + return "", fmt.Errorf("no main files found for call graph generation") + } + + fullModDir := filepath.Join(directory, modDir) + + // Generate call graph using the scanner's function + tempResult := &cg.Result{ + Directory: directory, + Errors: []string{}, + } + + if showProgress { + fmt.Fprintf(os.Stderr, " Building call graph from %s...\n", fullModDir) + } + + callGraphOutput, err := tempResult.GenerateCallGraphForVisualization(fullModDir, files) + if err != nil { + return "", fmt.Errorf("failed to generate call graph: %v", err) + } + + // Filter the call graph to only include paths to the target symbol + if showProgress { + fmt.Fprintf(os.Stderr, " Filtering graph for symbol %s...\n", symbol) + } + + filteredGraph, err := filterCallGraphForSymbol(callGraphOutput, pkg, symbol) + if err != nil { + return "", fmt.Errorf("failed to filter call graph: %v", err) + } + + if filteredGraph == "" { + return "", fmt.Errorf("no paths found to symbol %s.%s", pkg, symbol) + } + + if showProgress { + fmt.Fprintf(os.Stderr, " Converting to DOT format...\n") + } + + // Convert call graph to DOT format using digraph + dotCmd := exec.Command("digraph", "to", "dot") + dotCmd.Stdin = strings.NewReader(filteredGraph) + dotOutput, err := dotCmd.Output() + if err != nil { + return "", fmt.Errorf("failed to convert to DOT format: %v", err) + } + + if showProgress { + fmt.Fprintf(os.Stderr, " Rendering SVG with sfdp layout...\n") + } + + // Convert DOT to SVG using sfdp + sfdpCmd := exec.Command("sfdp", "-Tsvg", "-Goverlap=scale") + sfdpCmd.Stdin = bytes.NewReader(dotOutput) + svgOutput, err := sfdpCmd.Output() + if err != nil { + return "", fmt.Errorf("failed to render SVG: %v", err) + } + + // Write SVG to file + if err := os.WriteFile(outputPath, svgOutput, 0644); err != nil { + return "", fmt.Errorf("failed to write SVG file: %v", err) + } + + return outputPath, nil +} + +// filterCallGraphForSymbol filters the call graph to only include nodes relevant to reaching the target symbol +func filterCallGraphForSymbol(callGraphOutput, pkg, symbol string) (string, error) { + // Build a set of all nodes that can reach the target symbol + // Use digraph to find all paths from main to the symbol + scanner := bufio.NewScanner(strings.NewReader(callGraphOutput)) + var entryPoints []string + seen := make(map[string]bool) + + // Find all entry points (main functions) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) >= 1 { + caller := fields[0] + if strings.HasSuffix(caller, ".main") && !seen[caller] { + entryPoints = append(entryPoints, caller) + seen[caller] = true + } + } + } + + if len(entryPoints) == 0 { + return "", fmt.Errorf("no entry points found") + } + + // Find all paths from entry points to the target symbol + reachableNodes := make(map[string]bool) + + // Try different symbol patterns + symbolPatterns := []string{ + fmt.Sprintf("%s.%s", pkg, symbol), + fmt.Sprintf("(%s.%s)", pkg, symbol), + fmt.Sprintf("(*%s.%s)", pkg, symbol), + } + + for _, entryPoint := range entryPoints { + for _, symPattern := range symbolPatterns { + // Use digraph allpaths to find all paths from entry point to symbol + cmd := exec.Command("digraph", "allpaths", entryPoint, symPattern) + cmd.Stdin = strings.NewReader(callGraphOutput) + output, err := cmd.Output() + + if err == nil && len(bytes.TrimSpace(output)) > 0 { + // Parse the paths and add all nodes to reachableNodes + pathScanner := bufio.NewScanner(bytes.NewReader(output)) + for pathScanner.Scan() { + line := pathScanner.Text() + fields := strings.Fields(line) + if len(fields) >= 2 { + reachableNodes[fields[0]] = true + reachableNodes[fields[1]] = true + } + } + } + } + } + + if len(reachableNodes) == 0 { + return "", fmt.Errorf("no paths found to symbol") + } + + // Build filtered graph with only reachable nodes + var filteredLines []string + scanner = bufio.NewScanner(strings.NewReader(callGraphOutput)) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) >= 2 { + caller := fields[0] + callee := fields[1] + // Include edge if both nodes are reachable + if reachableNodes[caller] && reachableNodes[callee] { + filteredLines = append(filteredLines, line) + } + } + } + + return strings.Join(filteredLines, "\n"), nil +} + // initResultWithProgress initializes the result with progress tracking for each phase func initResultWithProgress(cve, dir string, fix bool, library, symbols, fixversion string) *cg.Result { r := &cg.Result{ diff --git a/cmd/gvs/main.go b/cmd/gvs/main.go index 3baef82..697de47 100644 --- a/cmd/gvs/main.go +++ b/cmd/gvs/main.go @@ -31,7 +31,7 @@ func main() { log.Fatalf("Failed to create directory: %v", err) } - http.Handle("/cg/", gvs.LogFileAccess(http.StripPrefix("/cg/", http.FileServer(http.Dir("/tmp/gvs-cache/img"))))) + http.Handle("/graph/", gvs.LogFileAccess(http.StripPrefix("/graph/", http.FileServer(http.Dir("/tmp/gvs-cache/graph"))))) http.Handle("/", http.FileServer(http.Dir("./site"))) http.HandleFunc("/scan", api.CORSMiddleware(api.ScanHandler)) http.HandleFunc("/healthz", api.CORSMiddleware(api.HealthHandler)) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index ec8604b..30782fd 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -258,7 +258,8 @@ func CallgraphHandler(w http.ResponseWriter, r *http.Request) { Library string `json:"library"` Symbol string `json:"symbol"` FixVersion string `json:"fixversion"` - Fix bool `json:"fix"` + Algo string `json:"algo"` + Graph bool `json:"graph"` ShowProgress bool `json:"showProgress"` } @@ -302,7 +303,18 @@ func CallgraphHandler(w http.ResponseWriter, r *http.Request) { progressStreams[taskId] = make(chan string, 100) progressMutex.Unlock() - go func(taskId, repo, branchOrCommit, cve, library, symbol, fixversion string, fix bool) { + // Get the base URL from the request for graph path conversion + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + // Check for X-Forwarded-Proto header (for reverse proxy setups) + if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { + scheme = proto + } + baseURL := fmt.Sprintf("%s://%s", scheme, r.Host) + + go func(taskId, repo, branchOrCommit, cve, library, symbol, fixversion, algo, baseURL string, graph bool) { defer func() { requestMutex.Lock() inProgress = false @@ -337,36 +349,14 @@ func CallgraphHandler(w http.ResponseWriter, r *http.Request) { updateStatus(StatusRunning, "", "") - // Include library, symbol, and fixversion in cache key if provided - cacheKey := fmt.Sprintf("%s@%s:%s:lib=%s:sym=%s:fixver=%s:fix=%t", repo, branchOrCommit, cve, library, symbol, fixversion, fix) + // Include library, symbol, fixversion, algo, and graph in cache key if provided + cacheKey := fmt.Sprintf("%s@%s:%s:lib=%s:sym=%s:fixver=%s:algo=%s:graph=%t", repo, branchOrCommit, cve, library, symbol, fixversion, algo, graph) if cachedData, err := RetrieveCacheFromDisk(cacheKey); err == nil { updateStatus(StatusCompleted, string(cachedData), "") log.Printf("[Task %s] Retrieved callgraph from cache", taskId) return } - // If fix=true and no direct cache, check for fix=false cache - // We can reuse the cached directory and execute fix commands from the cached data - if fix { - fallbackCacheKey := fmt.Sprintf("%s@%s:%s:lib=%s:sym=%s:fixver=%s:fix=false", repo, branchOrCommit, cve, library, symbol, fixversion) - if fallbackCachedData, err := RetrieveCacheFromDisk(fallbackCacheKey); err == nil { - // Parse the cached data, execute fix commands, and create fix=true response - start := time.Now() - log.Printf("[Task %s] Converting fix=false cache to fix=true cache and executing fixes ...", taskId) - if optimizedOutput, err := ConvertCacheForRunFix(fallbackCachedData); err == nil { - updateStatus(StatusCompleted, string(optimizedOutput), "") - log.Printf("[Task %s] Retrieved and executed fixes using fix=false cache - Took %s", taskId, time.Since(start)) - // Save the converted output to fix=true cache for future use - if err := SaveCacheToDisk(cacheKey, optimizedOutput); err != nil { - log.Printf("[Task %s] Failed to save converted cache: %v", taskId, err) - } - return - } else { - log.Printf("[Task %s] Failed to convert fallback cache: %v", taskId, err) - } - } - } - cloneDir, err := os.MkdirTemp("", "cg-"+path.Base(repo)+"-*") if err != nil { updateStatus(StatusFailed, "", fmt.Sprintf("failed to create temp dir: %v", err)) @@ -384,13 +374,37 @@ func CallgraphHandler(w http.ResponseWriter, r *http.Request) { sendProgress(fmt.Sprintf("Clone successful - Took %s", time.Since(start))) log.Printf("[Task %s] Running cg ...", taskId) - sendProgress("Running vulnerability analysis...") + // Display algorithm being used (default to rta if not specified) + algoName := algo + if algoName == "" { + algoName = "rta" + } + sendProgress(fmt.Sprintf("Running vulnerability analysis (algorithm: %s)...", algoName)) start = time.Now() var cmd *exec.Cmd // Build command arguments based on whether library/symbols are provided args := []string{"-progress"} - if fix { - args = append(args, "-fix") + if algo != "" { + args = append(args, fmt.Sprintf("-algo=%s", algo)) + } + if graph { + // Create subdirectory structure: /graph/CVE-XXXX/repo/branch/algo/ + // Sanitize repo name (extract just the repo name from URL) + repoName := path.Base(repo) + repoName = strings.TrimSuffix(repoName, ".git") + // Sanitize branch/commit for filesystem + sanitizedBranch := strings.ReplaceAll(branchOrCommit, "/", "-") + // Sanitize CVE for filesystem + sanitizedCVE := strings.ReplaceAll(cve, "/", "-") + if sanitizedCVE == "" { + sanitizedCVE = "unknown-cve" + } + + graphDir := filepath.Join("/tmp/gvs-cache/graph", sanitizedCVE, repoName, sanitizedBranch, algoName) + if err := os.MkdirAll(graphDir, 0755); err != nil { + log.Printf("[Task %s] Failed to create graph directory: %v", taskId, err) + } + args = append(args, fmt.Sprintf("-graph=%s", graphDir)) } if library != "" && symbol != "" { args = append(args, "-library", library, "-symbols", symbol) @@ -410,12 +424,18 @@ func CallgraphHandler(w http.ResponseWriter, r *http.Request) { } log.Printf("[Task %s] cg execution completed - Took %s", taskId, time.Since(start)) + + // If graph generation was requested, convert file paths to web-accessible URLs + if graph { + output = convertGraphPathsToURLs(output, baseURL) + } + updateStatus(StatusCompleted, string(output), "") if err := SaveCacheToDisk(cacheKey, output); err != nil { log.Printf("[Task %s] Failed to save cache: %v", taskId, err) } - }(taskId, req.Repo, req.BranchOrCommit, req.CVE, req.Library, req.Symbol, req.FixVersion, req.Fix) + }(taskId, req.Repo, req.BranchOrCommit, req.CVE, req.Library, req.Symbol, req.FixVersion, req.Algo, baseURL, req.Graph) w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(map[string]string{"taskId": taskId}); err != nil { @@ -574,3 +594,51 @@ func runGovulncheckWithProgress(directory, target string, sendProgress func(stri return output, exitCode, err } + +// convertGraphPathsToURLs converts file paths in GraphPaths to web-accessible URLs +// Paths are expected to be like: /tmp/gvs-cache/graph/CVE-XXXX/repo/branch/algo/library-symbol.svg +// Converts to: http://host:port/graph/CVE-XXXX/repo/branch/algo/library-symbol.svg +func convertGraphPathsToURLs(jsonOutput []byte, baseURL string) []byte { + var result map[string]interface{} + if err := json.Unmarshal(jsonOutput, &result); err != nil { + log.Printf("Failed to parse JSON for graph path conversion: %v", err) + return jsonOutput + } + + // Check if GraphPaths field exists + graphPaths, ok := result["GraphPaths"].([]interface{}) + if !ok || len(graphPaths) == 0 { + return jsonOutput + } + + // Convert file paths to full URLs + // Extract the path relative to /tmp/gvs-cache/graph and prefix with baseURL/graph/ + webPaths := make([]string, 0, len(graphPaths)) + const cacheDir = "/tmp/gvs-cache/graph/" + for _, p := range graphPaths { + if pathStr, ok := p.(string); ok { + // Extract the relative path after /tmp/gvs-cache/graph/ + if strings.HasPrefix(pathStr, cacheDir) { + relativePath := strings.TrimPrefix(pathStr, cacheDir) + webURL := baseURL + "/graph/" + relativePath + webPaths = append(webPaths, webURL) + } else { + // Fallback: just use the filename + filename := filepath.Base(pathStr) + webURL := baseURL + "/graph/" + filename + webPaths = append(webPaths, webURL) + } + } + } + + result["GraphPaths"] = webPaths + + // Re-encode to JSON + modifiedJSON, err := json.Marshal(result) + if err != nil { + log.Printf("Failed to re-encode JSON after graph path conversion: %v", err) + return jsonOutput + } + + return modifiedJSON +} diff --git a/internal/api/integration_test.go b/internal/api/integration_test.go index 9e9d517..72fb0e6 100644 --- a/internal/api/integration_test.go +++ b/internal/api/integration_test.go @@ -66,7 +66,7 @@ func TestCallgraphIntegration(t *testing.T) { "repo": "https://github.com/openshift/sriov-network-device-plugin", "branchOrCommit": "95ebce39bc8d15f498abb0db0cb5b464db9a4887", "cve": "CVE-2024-45339", - "runFix": false, + "algo": "rta", } reqJSON, err := json.Marshal(requestBody) @@ -228,7 +228,7 @@ func TestCallgraphIntegrationVulnerable(t *testing.T) { "repo": "https://github.com/openshift/sriov-network-device-plugin", "branchOrCommit": "c600016ab638aab33bf02be5414f4174033c744a", "cve": "CVE-2024-45339", - "runFix": false, + "algo": "rta", } reqJSON, err := json.Marshal(requestBody) diff --git a/pkg/cmd/cg/scanner.go b/pkg/cmd/cg/scanner.go index d392c46..22e0ea3 100644 --- a/pkg/cmd/cg/scanner.go +++ b/pkg/cmd/cg/scanner.go @@ -3,20 +3,21 @@ // Call Graph Algorithms: // The scanner supports multiple call graph algorithms, configurable via the ALGO environment variable: // -// - vta (default): Variable Type Analysis - Most precise but slower. Recommended for accuracy. +// - rta (default): Rapid Type Analysis - Good balance of speed and precision. Best for reflection tracking. // - cha: Class Hierarchy Analysis - Fast but less precise. Good for large codebases where speed matters. -// - rta: Rapid Type Analysis - Good balance of speed and precision. Suitable for most use cases. +// - vta: Variable Type Analysis - Most precise for direct calls but slower. Less effective for reflection. // - static: Static analysis - Very fast but least precise (only direct calls). Use for quick scans. // // Usage: // -// export ALGO=rta # Use Rapid Type Analysis -// export ALGO=cha # Use Class Hierarchy Analysis -// export ALGO=static # Use static analysis -// # Default (no env var set) uses VTA +// export ALGO=vta # Use Variable Type Analysis +// export ALGO=cha # Use Class Hierarchy Analysis +// export ALGO=static # Use static analysis +// # Default (no env var set) uses RTA // // Algorithm Trade-offs: -// - Precision: static < cha < rta < vta +// - Precision for direct calls: static < cha < rta < vta +// - Reflection tracking: vta < static < cha < rta // - Speed: vta < rta < cha < static package cg @@ -31,8 +32,8 @@ import ( "io" "net/http" "os" + "os/exec" "path/filepath" - "regexp" "slices" "sort" "strings" @@ -166,8 +167,8 @@ func InitResult(cve, dir string, fix bool, library, symbols, fixversion string) // Check if it's a stdlib package if strings.HasPrefix(library, "crypto/") || strings.HasPrefix(library, "net/") || - strings.HasPrefix(library, "encoding/") || strings.HasPrefix(library, "os/") || - !strings.Contains(library, ".") { + strings.HasPrefix(library, "encoding/") || strings.HasPrefix(library, "os/") || + !strings.Contains(library, ".") { entry := r.AffectedImports[library] entry.Type = "stdlib" r.AffectedImports[library] = entry @@ -592,32 +593,17 @@ func (r *Result) isSymbolUsed(pkg, dir string, symbols, files []string) string { // Check for reflection-based usage reflectionRisks := r.detectReflectionVulnerabilities(pkg, dir, originalSymbols, files) - // Store reflection risks + // Store reflection risks for informational purposes only + // Note: Reflection risks do NOT affect IsVulnerable status - only call graph findings do if len(reflectionRisks) > 0 { r.Mu.Lock() r.ReflectionRisks = append(r.ReflectionRisks, reflectionRisks...) r.Mu.Unlock() } - // Update used imports with reflection findings - if directUsage == "true" || len(reflectionRisks) > 0 { - r.Mu.Lock() - if r.UsedImports == nil { - r.UsedImports = make(map[string]UsedImportsDetails) - } - entry := r.UsedImports[pkg] - - // Add reflection-detected symbols - for _, risk := range reflectionRisks { - entry.Symbols = append(entry.Symbols, risk.Symbol) - } - - r.UsedImports[pkg] = entry - r.Mu.Unlock() - return "true" - } - - return "false" + // Only return "true" if call graph found direct usage + // Reflection risks are informational only and don't determine vulnerability status + return directUsage } // checkDirectUsage handles the existing call graph analysis @@ -630,17 +616,24 @@ func (r *Result) checkDirectUsage(pkg, dir string, symbols []string, files []str return "unknown" } - // Convert to bytes for compatibility with existing matchSymbol function - out := []byte(callGraphOutput) + // Find entry points (main functions) from the call graph + entryPoints := extractEntryPoints(callGraphOutput) + if len(entryPoints) == 0 { + errMsg := fmt.Sprintf("No entry points (main functions) found in call graph for %s", dir) + r.Errors = append(r.Errors, errMsg) + return "unknown" + } var wg sync.WaitGroup - found := false + var mu sync.Mutex + foundAny := false for _, symbol := range symbols { wg.Add(1) go func(sym string) { defer wg.Done() - if matchSymbol(out, sym) { + // Check if symbol is reachable from any entry point using digraph somepath + if isSymbolReachableFromAny(callGraphOutput, entryPoints, sym) { r.Mu.Lock() if r.UsedImports == nil { r.UsedImports = make(map[string]UsedImportsDetails) @@ -649,47 +642,40 @@ func (r *Result) checkDirectUsage(pkg, dir string, symbols []string, files []str entry.Symbols = append(entry.Symbols, sym) r.UsedImports[pkg] = entry r.Mu.Unlock() - found = true + + mu.Lock() + foundAny = true + mu.Unlock() } }(symbol) } wg.Wait() - if found { + if foundAny { return "true" } return "false" } +// GenerateCallGraphForVisualization is a public wrapper for call graph generation for visualization +func (r *Result) GenerateCallGraphForVisualization(dir string, files []string) (string, error) { + return r.generateCallGraphWithLib(dir, files) +} + // generateCallGraphWithLib creates a call graph using the callgraph library func (r *Result) generateCallGraphWithLib(dir string, files []string) (string, error) { - // Determine package patterns to load based on the files - packagePatterns := make(map[string]bool) - - // Add the current directory and any subdirectories containing the files - packagePatterns["."] = true - for _, file := range files { - packageDir := filepath.Dir(file) - if packageDir != "." { - packagePatterns["./"+packageDir] = true - } - } - - var patterns []string - for pattern := range packagePatterns { - patterns = append(patterns, pattern) - } - // Load packages with comprehensive mode to handle all dependencies + // Use "./..." to load all packages in the module - this is required for + // RTA to properly track reflection-based calls like reflect.ValueOf(func).Call() cfg := &packages.Config{ Mode: packages.LoadAllSyntax, // This loads everything needed for analysis Dir: dir, Env: append(os.Environ(), "GOFLAGS=-mod=mod", "GOWORK=off"), } - // Load packages - pkgs, err := packages.Load(cfg, patterns...) + // Load all packages in the module (like callgraph binary does by default) + pkgs, err := packages.Load(cfg, "./...") if err != nil { return "", fmt.Errorf("failed to load packages: %v", err) } @@ -708,7 +694,7 @@ func (r *Result) generateCallGraphWithLib(dir string, files []string) (string, e // If no valid packages, try loading with less strict requirements if len(validPkgs) == 0 { - // Try with just the module root pattern + // Try with just the module root pattern with less strict mode cfg = &packages.Config{ Mode: packages.LoadSyntax, Dir: dir, @@ -760,28 +746,60 @@ func (r *Result) generateCallGraphWithLib(dir string, files []string) (string, e return output.String(), nil } -func matchSymbol(out []byte, symbol string) bool { - // Match symbol as either caller (at start or after space) or callee (before space or at end) - // Call graph format: "Caller Callee" - quotedSymbol := regexp.QuoteMeta(symbol) - // Pattern matches: (start or space) + symbol + (space or end) - pattern := regexp.MustCompile(`(^|\s)` + quotedSymbol + `(\s|$)`) - scanner := bufio.NewScanner(bytes.NewReader(out)) +// extractEntryPoints finds all main functions in the call graph +func extractEntryPoints(callGraphOutput string) []string { + var entryPoints []string + seen := make(map[string]bool) + + scanner := bufio.NewScanner(strings.NewReader(callGraphOutput)) for scanner.Scan() { - if pattern.MatchString(scanner.Text()) { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) >= 1 { + caller := fields[0] + // Look for main functions (e.g., "command-line-arguments.main", "package.main") + if strings.HasSuffix(caller, ".main") && !seen[caller] { + entryPoints = append(entryPoints, caller) + seen[caller] = true + } + } + } + + return entryPoints +} + +// isSymbolReachableFromAny checks if a symbol is reachable from any entry point +func isSymbolReachableFromAny(callGraphOutput string, entryPoints []string, symbol string) bool { + for _, entryPoint := range entryPoints { + if isSymbolReachable(callGraphOutput, entryPoint, symbol) { return true } } return false } +// isSymbolReachable uses digraph somepath to check if there's a path from entry to symbol +func isSymbolReachable(callGraphOutput, entryPoint, symbol string) bool { + cmd := exec.Command("digraph", "somepath", entryPoint, symbol) + cmd.Stdin = strings.NewReader(callGraphOutput) + output, err := cmd.Output() + if err != nil { + // No path found or error occurred + return false + } + // If digraph finds a path, it outputs the path; empty means no path + return len(bytes.TrimSpace(output)) > 0 +} + // getCallGraphAlgorithm returns the algorithm to use for call graph generation // based on the ALGO environment variable. -// Supported algorithms: vta (default), cha, rta, static +// Supported algorithms: rta (default), cha, vta, static +// RTA is the default because it better handles reflection-based calls +// (e.g., reflect.ValueOf(func).Call()) compared to VTA. func getCallGraphAlgorithm() string { algo := os.Getenv("ALGO") if algo == "" { - algo = "vta" // default algorithm + algo = "rta" // default algorithm - better for reflection tracking } return strings.ToLower(algo) } @@ -796,16 +814,17 @@ func buildCallGraph(prog *ssa.Program, algo string) *callgraph.Graph { return cha.CallGraph(prog) case "rta": // Rapid Type Analysis - good balance of speed and precision + // Better at tracking reflection-based calls (e.g., reflect.ValueOf(func).Call()) return buildRTACallGraph(prog, allFuncs) case "static": // Static analysis - very fast but least precise (only direct calls) return static.CallGraph(prog) case "vta": - // Variable Type Analysis - most precise but slower (default) + // Variable Type Analysis - most precise for direct calls but slower return vta.CallGraph(allFuncs, nil) default: - // Default to VTA if unknown algorithm specified - return vta.CallGraph(allFuncs, nil) + // Default to RTA if unknown algorithm specified (best for reflection tracking) + return buildRTACallGraph(prog, allFuncs) } } @@ -1050,8 +1069,9 @@ func extractGoVersion(fixedVersion string) string { fixedVersion = strings.TrimPrefix(fixedVersion, "go") // Ensure it's a valid semver format (add "v" prefix if missing) - if regexp.MustCompile(`^\d+\.\d+`).MatchString(fixedVersion) { - fixedVersion = "v" + strings.TrimPrefix(fixedVersion, "v") + // Check if it starts with a digit (e.g., "1.21.4") + if len(fixedVersion) > 0 && fixedVersion[0] >= '0' && fixedVersion[0] <= '9' { + fixedVersion = "v" + fixedVersion } return fixedVersion diff --git a/pkg/cmd/cg/types.go b/pkg/cmd/cg/types.go index e4eb633..fa414a3 100644 --- a/pkg/cmd/cg/types.go +++ b/pkg/cmd/cg/types.go @@ -53,6 +53,7 @@ type Result struct { Unsafe bool `json:"unsafe"` Reflect bool `json:"reflect"` ReflectionRisks []ReflectionRisk `json:"reflection_risks,omitempty"` // New field + GraphPaths []string `json:"GraphPaths,omitempty"` // Paths to generated SVG graphs (one per symbol) Mu sync.Mutex `json:"-"` Summary string } diff --git a/site/index.html b/site/index.html index 88dd423..1f7fcaf 100644 --- a/site/index.html +++ b/site/index.html @@ -29,7 +29,7 @@

Golang Vulnerability Scanner

-
+
- + +
+
+ - diff --git a/site/script.js b/site/script.js index a2edee2..0828d28 100644 --- a/site/script.js +++ b/site/script.js @@ -144,23 +144,10 @@ document.addEventListener('DOMContentLoaded', function() { // Initialize form history manager window.formHistory = new FormHistoryManager(); - // Add event listener to CVE ID, Library, and Symbol fields to automatically set Run Fix state + // Add event listener to CVE ID, Library, and Symbol fields const cveInput = document.getElementById('cve'); const libraryInput = document.getElementById('library'); const symbolInput = document.getElementById('symbol'); - const fixSelect = document.getElementById('fix'); - - function updateRunFixState() { - const hasCve = cveInput.value.trim() !== ''; - const hasLibraryAndSymbol = libraryInput.value.trim() !== '' && symbolInput.value.trim() !== ''; - - if (hasCve || hasLibraryAndSymbol) { - fixSelect.disabled = false; - } else { - fixSelect.value = 'false'; - fixSelect.disabled = true; - } - } window.validateCVEInput = function() { const value = cveInput.value.trim(); @@ -182,17 +169,10 @@ document.addEventListener('DOMContentLoaded', function() { } } - // Initialize Run Fix state on page load - updateRunFixState(); - - // Update Run Fix state whenever CVE, Library, or Symbol inputs change + // Validate CVE input on change cveInput.addEventListener('input', function() { - updateRunFixState(); window.validateCVEInput(); }); - - libraryInput.addEventListener('input', updateRunFixState); - symbolInput.addEventListener('input', updateRunFixState); }); const sustainingQuestions = [ @@ -236,13 +216,20 @@ function syntaxHighlight(json) { json = json.replace(/\\"/g, """); - json = json + // Convert graph URLs to clickable links BEFORE other highlighting (to preserve URL structure) + // Matches http/https URLs ending with .svg + json = json.replace(/"(https?:\/\/[^"]+\.svg)"/g, function(match, url) { + return '"' + url + '"'; + }); + + // Apply syntax highlighting, but skip URLs inside tags + json = json .replace(/("(\w+)":)/g, '$1') - .replace(/(:\s*")([^"]*?)(")/g, ': "$2"') - .replace(/(:\s*(\d+))/g, ': $2'); + .replace(/(:\s*")(?![^"]*"$2"') + .replace(/(:\s*(\d+))(?![^<]*<\/a>)/g, ':$2'); - return json; - } + return json; +} function runScan() { if (scanInProgress) return; @@ -298,8 +285,8 @@ function runScan() { const repo = document.getElementById("repo").value.trim(); const branchOrCommit = document.getElementById("branchOrCommit").value.trim(); const cve = document.getElementById("cve").value.trim(); - // Automatically set fix to false if CVE ID is empty and no library/symbol override - const fix = (cve || (library && symbol)) ? document.getElementById("fix").value === "true" : false; + const algo = document.getElementById("algo").value; + const graph = (cve || (library && symbol)) ? document.getElementById("graph").value === "true" : false; const outputDiv = document.getElementById("output"); const progressContent = document.getElementById("progressContent"); const scanButton = document.getElementById("scanButton"); @@ -355,7 +342,7 @@ function runScan() { headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ repo, branchOrCommit, cve, library, symbol, fixversion, fix: fix, showProgress: true }) + body: JSON.stringify({ repo, branchOrCommit, cve, library, symbol, fixversion, algo, graph, showProgress: true }) }) .then(response => response.json()) .then(data => {