diff --git a/internal/adapter/checkstyle/converter.go b/internal/adapter/checkstyle/converter.go index d4ac1c8..c8d12f2 100644 --- a/internal/adapter/checkstyle/converter.go +++ b/internal/adapter/checkstyle/converter.go @@ -121,17 +121,17 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l // Separate modules into Checker-level and TreeWalker-level // Checker-level modules (NOT under TreeWalker) checkerLevelModules := map[string]bool{ - "LineLength": true, - "FileLength": true, - "FileTabCharacter": true, - "NewlineAtEndOfFile": true, - "UniqueProperties": true, - "OrderedProperties": true, - "Translation": true, - "SuppressWarningsFilter": true, + "LineLength": true, + "FileLength": true, + "FileTabCharacter": true, + "NewlineAtEndOfFile": true, + "UniqueProperties": true, + "OrderedProperties": true, + "Translation": true, + "SuppressWarningsFilter": true, "BeforeExecutionExclusionFileFilter": true, - "SuppressionFilter": true, - "SuppressionCommentFilter": true, + "SuppressionFilter": true, + "SuppressionCommentFilter": true, } var checkerModules []checkstyleModule @@ -255,8 +255,8 @@ Output: userPrompt := fmt.Sprintf("Convert this Java rule to Checkstyle module:\n\n%s", rule.Say) - // Call LLM with power model + low reasoning - response, err := llmClient.Request(systemPrompt, userPrompt).WithPower(llm.ReasoningMinimal).Execute(ctx) + // Call LLM with minimal complexity + response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) if err != nil { return nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/adapter/eslint/converter.go b/internal/adapter/eslint/converter.go index 42d5b86..66decfb 100644 --- a/internal/adapter/eslint/converter.go +++ b/internal/adapter/eslint/converter.go @@ -217,8 +217,8 @@ Output: userPrompt += fmt.Sprintf("\nSeverity: %s", rule.Severity) } - // Call LLM with power model + low reasoning - response, err := 
llmClient.Request(systemPrompt, userPrompt).WithPower(llm.ReasoningMinimal).Execute(ctx) + // Call LLM with minimal complexity + response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) if err != nil { return "", nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/adapter/pmd/converter.go b/internal/adapter/pmd/converter.go index 09116c8..55a76e5 100644 --- a/internal/adapter/pmd/converter.go +++ b/internal/adapter/pmd/converter.go @@ -188,8 +188,8 @@ IMPORTANT: Return ONLY the JSON object. Do NOT include description, message, or userPrompt := fmt.Sprintf("Convert this Java rule to PMD rule reference:\n\n%s", rule.Say) - // Call LLM with power model + low reasoning - response, err := llmClient.Request(systemPrompt, userPrompt).WithPower(llm.ReasoningMinimal).Execute(ctx) + // Call LLM with minimal complexity + response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) if err != nil { return nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/adapter/prettier/converter.go b/internal/adapter/prettier/converter.go index 010d95e..2d951f9 100644 --- a/internal/adapter/prettier/converter.go +++ b/internal/adapter/prettier/converter.go @@ -134,7 +134,7 @@ Output: userPrompt := fmt.Sprintf("Convert this rule to Prettier configuration:\n\n%s", rule.Say) // Call LLM - response, err := llmClient.Request(systemPrompt, userPrompt).WithPower(llm.ReasoningMinimal).Execute(ctx) + response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) if err != nil { return nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/adapter/pylint/converter.go b/internal/adapter/pylint/converter.go index 7bd6b55..e07ece1 100644 --- a/internal/adapter/pylint/converter.go +++ b/internal/adapter/pylint/converter.go @@ -210,7 +210,7 @@ Output: } // Call LLM - response, err := 
llmClient.Request(systemPrompt, userPrompt).WithPower(llm.ReasoningMinimal).Execute(ctx) + response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) if err != nil { return "", nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/adapter/tsc/converter.go b/internal/adapter/tsc/converter.go index 9259e34..7dc4a9f 100644 --- a/internal/adapter/tsc/converter.go +++ b/internal/adapter/tsc/converter.go @@ -148,7 +148,7 @@ Output: userPrompt := fmt.Sprintf("Convert this rule to TypeScript compiler configuration:\n\n%s", rule.Say) // Call LLM - response, err := llmClient.Request(systemPrompt, userPrompt).WithPower(llm.ReasoningMinimal).Execute(ctx) + response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) if err != nil { return nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/cmd/api_key.go b/internal/cmd/api_key.go index 43fb924..ec5c33e 100644 --- a/internal/cmd/api_key.go +++ b/internal/cmd/api_key.go @@ -16,11 +16,6 @@ func promptAPIKeySetup() { promptAPIKeyConfiguration(false) } -// promptAPIKeyIfNeeded checks if OpenAI API key is configured and prompts if not -func promptAPIKeyIfNeeded() { - promptAPIKeyConfiguration(true) -} - // promptAPIKeyConfiguration handles API key configuration with optional existence check func promptAPIKeyConfiguration(checkExisting bool) { envPath := filepath.Join(".sym", ".env") diff --git a/internal/cmd/convert.go b/internal/cmd/convert.go index 6787c9b..1b458ad 100644 --- a/internal/cmd/convert.go +++ b/internal/cmd/convert.go @@ -40,6 +40,8 @@ map them to appropriate linter rules.`, # Convert for specific linter sym convert -i user-policy.json --targets eslint + # Convert for Java with specific model + sym convert -i user-policy.json --targets checkstyle,pmd --openai-model gpt-5-mini # Convert for Java with specific model sym convert -i user-policy.json --targets checkstyle,pmd --openai-model 
gpt-5-mini @@ -131,18 +133,19 @@ func runNewConverter(userPolicy *schema.UserPolicy) error { convertOutputDir = ".sym" } - // Setup OpenAI client - apiKey, err := getAPIKey() - if err != nil { - return fmt.Errorf("OpenAI API key required: %w", err) - } - timeout := time.Duration(convertTimeout) * time.Second llmClient := llm.NewClient( - apiKey, llm.WithTimeout(timeout), ) + // Ensure at least one backend is available (MCP/CLI/API) + availabilityCtx, cancelAvailability := context.WithTimeout(context.Background(), timeout) + defer cancelAvailability() + + if err := llmClient.CheckAvailability(availabilityCtx); err != nil { + return fmt.Errorf("no available LLM backend for convert: %w\nTip: run 'sym init --setup-llm' or configure LLM_BACKEND / LLM_CLI / OPENAI_API_KEY in .sym/.env", err) + } + // Create new converter conv := converter.NewConverter(llmClient, convertOutputDir) diff --git a/internal/cmd/init.go b/internal/cmd/init.go index e1f2415..feae6c7 100644 --- a/internal/cmd/init.go +++ b/internal/cmd/init.go @@ -33,19 +33,23 @@ This command: } var ( - initForce bool - skipMCPRegister bool - registerMCPOnly bool - skipAPIKey bool - setupAPIKeyOnly bool + initForce bool + skipMCPRegister bool + registerMCPOnly bool + skipAPIKey bool + setupAPIKeyOnly bool + skipLLMSetup bool + setupLLMOnly bool ) func init() { initCmd.Flags().BoolVarP(&initForce, "force", "f", false, "Overwrite existing roles.json") initCmd.Flags().BoolVar(&skipMCPRegister, "skip-mcp", false, "Skip MCP server registration prompt") initCmd.Flags().BoolVar(®isterMCPOnly, "register-mcp", false, "Register MCP server only (skip roles/policy init)") - initCmd.Flags().BoolVar(&skipAPIKey, "skip-api-key", false, "Skip OpenAI API key configuration prompt") - initCmd.Flags().BoolVar(&setupAPIKeyOnly, "setup-api-key", false, "Setup OpenAI API key only (skip roles/policy init)") + initCmd.Flags().BoolVar(&skipAPIKey, "skip-api-key", false, "Skip OpenAI API key configuration prompt (deprecated, use 
--skip-llm)") + initCmd.Flags().BoolVar(&setupAPIKeyOnly, "setup-api-key", false, "Setup OpenAI API key only (deprecated, use --setup-llm)") + initCmd.Flags().BoolVar(&skipLLMSetup, "skip-llm", false, "Skip LLM backend configuration prompt") + initCmd.Flags().BoolVar(&setupLLMOnly, "setup-llm", false, "Setup LLM backend only (skip roles/policy init)") } func runInit(cmd *cobra.Command, args []string) { @@ -56,13 +60,20 @@ func runInit(cmd *cobra.Command, args []string) { return } - // API key setup only mode + // API key setup only mode (deprecated) if setupAPIKeyOnly { fmt.Println("๐Ÿ”‘ Setting up OpenAI API key...") promptAPIKeySetup() return } + // LLM setup only mode + if setupLLMOnly { + fmt.Println("๐Ÿค– Setting up LLM backend...") + promptLLMBackendSetup() + return + } + // Check if logged in if !config.IsLoggedIn() { fmt.Println("โŒ Not logged in") @@ -156,9 +167,9 @@ func runInit(cmd *cobra.Command, args []string) { promptMCPRegistration() } - // API key configuration prompt - if !skipAPIKey { - promptAPIKeyIfNeeded() + // LLM backend configuration prompt + if !skipLLMSetup && !skipAPIKey { + promptLLMBackendSetup() } // Show dashboard guide after all initialization is complete diff --git a/internal/cmd/llm.go b/internal/cmd/llm.go new file mode 100644 index 0000000..040f893 --- /dev/null +++ b/internal/cmd/llm.go @@ -0,0 +1,497 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "strings" + "time" + + "github.com/DevSymphony/sym-cli/internal/llm" + "github.com/DevSymphony/sym-cli/internal/llm/engine" + "github.com/manifoldco/promptui" + "github.com/spf13/cobra" +) + +var llmCmd = &cobra.Command{ + Use: "llm", + Short: "Manage LLM engine configuration", + Long: `Configure and manage LLM engines for Symphony. 
+ +Symphony supports multiple LLM engines: + - MCP Sampling: Uses the host LLM when running as MCP server + - CLI: Uses local CLI tools (claude, gemini) + - API: Uses OpenAI API directly + +The default mode is 'auto' which tries engines in this order: +MCP Sampling โ†’ CLI โ†’ API`, +} + +var llmSetupCmd = &cobra.Command{ + Use: "setup", + Short: "Interactive LLM engine setup", + Long: `Interactively configure which LLM engine to use.`, + Run: runLLMSetup, +} + +var llmStatusCmd = &cobra.Command{ + Use: "status", + Short: "Show current LLM engine status", + Long: `Display the current LLM engine configuration and availability.`, + Run: runLLMStatus, +} + +var llmTestCmd = &cobra.Command{ + Use: "test", + Short: "Test LLM engine connection", + Long: `Send a test request to verify LLM engine is working.`, + Run: runLLMTest, +} + +func init() { + rootCmd.AddCommand(llmCmd) + llmCmd.AddCommand(llmSetupCmd) + llmCmd.AddCommand(llmStatusCmd) + llmCmd.AddCommand(llmTestCmd) +} + +func runLLMSetup(_ *cobra.Command, _ []string) { + fmt.Println("๐Ÿค– LLM Engine Configuration") + fmt.Println() + + // Load current config + cfg := llm.LoadLLMConfig() + + // Show current settings + fmt.Println("Current settings:") + fmt.Printf(" Engine mode: %s\n", cfg.Backend) + if cfg.CLI != "" { + fmt.Printf(" CLI: %s\n", cfg.CLI) + } + if cfg.Model != "" { + fmt.Printf(" Model: %s\n", cfg.Model) + } + if cfg.HasAPIKey() { + fmt.Println(" API Key: configured") + } else { + fmt.Println(" API Key: not set") + } + fmt.Println() + + // Show menu + items := []string{ + "Configure CLI tool", + "Set OpenAI API key", + "Change engine mode", + "Test current configuration", + "Reset to defaults", + "Exit", + } + + templates := &promptui.SelectTemplates{ + Label: "{{ . }}?", + Active: "โ–ธ {{ . | cyan }}", + Inactive: " {{ . }}", + Selected: "โœ“ {{ . 
| green }}", + } + + selectPrompt := promptui.Select{ + Label: "What would you like to configure", + Items: items, + Templates: templates, + Size: 6, + } + + index, _, err := selectPrompt.Run() + if err != nil { + fmt.Println("\nSetup cancelled") + return + } + + switch index { + case 0: + configureCLI(cfg) + case 1: + promptAPIKeySetup() + case 2: + configureEngineMode(cfg) + case 3: + runLLMTest(nil, nil) + case 4: + resetLLMConfig() + case 5: + fmt.Println("\nExiting setup") + } +} + +func configureCLI(cfg *llm.LLMConfig) { + fmt.Println("\n๐Ÿ”ง CLI Tool Configuration") + fmt.Println() + + // Detect available CLIs + clis := engine.DetectAvailableCLIs() + + // Build selection items + var items []string + var availableCLIs []engine.CLIInfo + + for _, cli := range clis { + status := "โœ— not found" + if cli.Available { + status = "โœ“ available" + if cli.Version != "" { + status = fmt.Sprintf("โœ“ %s", cli.Version) + } + } + items = append(items, fmt.Sprintf("%s (%s)", cli.Name, status)) + availableCLIs = append(availableCLIs, cli) + } + + items = append(items, "Skip CLI configuration") + + templates := &promptui.SelectTemplates{ + Label: "{{ . }}?", + Active: "โ–ธ {{ . | cyan }}", + Inactive: " {{ . }}", + Selected: "โœ“ {{ . 
| green }}", + } + + selectPrompt := promptui.Select{ + Label: "Select CLI tool to use", + Items: items, + Templates: templates, + Size: len(items), + } + + index, _, err := selectPrompt.Run() + if err != nil || index >= len(availableCLIs) { + fmt.Println("\nCLI configuration skipped") + return + } + + selectedCLI := availableCLIs[index] + + if !selectedCLI.Available { + fmt.Printf("\nโš ๏ธ %s is not installed or not in PATH\n", selectedCLI.Name) + fmt.Println("Please install it first and try again") + return + } + + // Update config + cfg.CLI = string(selectedCLI.Provider) + + // Get provider for default model + provider, _ := engine.GetProvider(selectedCLI.Provider) + if provider != nil { + cfg.Model = provider.DefaultModel + cfg.LargeModel = provider.LargeModel + } + + // Save config + if err := llm.SaveLLMConfig(cfg); err != nil { + fmt.Printf("\nโŒ Failed to save configuration: %v\n", err) + return + } + + fmt.Printf("\nโœ“ CLI engine configured: %s\n", selectedCLI.Name) + if cfg.Model != "" { + fmt.Printf(" Default model: %s\n", cfg.Model) + } + if cfg.LargeModel != "" { + fmt.Printf(" Large model: %s\n", cfg.LargeModel) + } + fmt.Println(" Configuration saved to .sym/.env") +} + +func configureEngineMode(cfg *llm.LLMConfig) { + fmt.Println("\nโš™๏ธ Engine Mode Configuration") + fmt.Println() + + items := []string{ + "auto - Automatically select best available engine", + "mcp - Always use MCP sampling (when available)", + "cli - Always use CLI tool", + "api - Always use OpenAI API", + } + + templates := &promptui.SelectTemplates{ + Label: "{{ . }}?", + Active: "โ–ธ {{ . | cyan }}", + Inactive: " {{ . }}", + Selected: "โœ“ {{ . 
| green }}", + } + + selectPrompt := promptui.Select{ + Label: "Select engine mode", + Items: items, + Templates: templates, + Size: 4, + } + + index, _, err := selectPrompt.Run() + if err != nil { + fmt.Println("\nEngine mode configuration cancelled") + return + } + + modes := []engine.Mode{ + engine.ModeAuto, + engine.ModeMCP, + engine.ModeCLI, + engine.ModeAPI, + } + + cfg.Backend = modes[index] + + // Save config + if err := llm.SaveLLMConfig(cfg); err != nil { + fmt.Printf("\nโŒ Failed to save configuration: %v\n", err) + return + } + + fmt.Printf("\nโœ“ Engine mode set to: %s\n", cfg.Backend) +} + +func resetLLMConfig() { + fmt.Println("\n๐Ÿ”„ Resetting LLM Configuration") + + // Confirm + prompt := promptui.Prompt{ + Label: "Are you sure you want to reset LLM configuration", + IsConfirm: true, + } + + result, err := prompt.Run() + if err != nil || strings.ToLower(result) != "y" { + fmt.Println("\nReset cancelled") + return + } + + // Save default config + cfg := llm.DefaultLLMConfig() + if err := llm.SaveLLMConfig(cfg); err != nil { + fmt.Printf("\nโŒ Failed to reset configuration: %v\n", err) + return + } + + fmt.Println("\nโœ“ LLM configuration reset to defaults") +} + +func runLLMStatus(_ *cobra.Command, _ []string) { + fmt.Println("๐Ÿค– LLM Engine Status") + fmt.Println() + + // Load config + cfg := llm.LoadLLMConfig() + + // Create client to check engines + client := llm.NewClient(llm.WithConfig(cfg), llm.WithVerbose(false)) + + fmt.Println("Configuration:") + fmt.Printf(" Engine mode: %s\n", cfg.Backend) + if cfg.CLI != "" { + fmt.Printf(" CLI provider: %s\n", cfg.CLI) + } + if cfg.Model != "" { + fmt.Printf(" Model: %s\n", cfg.Model) + } + fmt.Println() + + // Show engine availability + fmt.Println("Engine availability:") + + engines := client.GetEngines() + if len(engines) == 0 { + fmt.Println(" โš ๏ธ No engines configured") + } else { + for _, e := range engines { + status := "โœ— unavailable" + if e.IsAvailable() { + status = "โœ“ available" + } 
+ fmt.Printf(" %s: %s\n", e.Name(), status) + } + } + + fmt.Println() + + // Show active engine + active := client.GetActiveEngine() + if active != nil { + fmt.Printf("Active engine: %s\n", active.Name()) + + caps := active.Capabilities() + fmt.Println("Capabilities:") + fmt.Printf(" Temperature: %v\n", caps.SupportsTemperature) + fmt.Printf(" Max tokens: %v\n", caps.SupportsMaxTokens) + fmt.Printf(" Complexity hint: %v\n", caps.SupportsComplexity) + } else { + fmt.Println("โš ๏ธ No active engine available") + } + + fmt.Println() + fmt.Println("๐Ÿ’ก Run 'sym llm setup' to configure engines") + fmt.Println("๐Ÿ’ก Run 'sym llm test' to verify connection") +} + +func runLLMTest(_ *cobra.Command, _ []string) { + fmt.Println("๐Ÿงช Testing LLM Engine Connection") + fmt.Println() + + // Load config + cfg := llm.LoadLLMConfig() + + // Create client + client := llm.NewClient(llm.WithConfig(cfg), llm.WithVerbose(true)) + + active := client.GetActiveEngine() + if active == nil { + fmt.Println("โŒ No LLM engine available") + fmt.Println() + fmt.Println("Please configure an engine:") + fmt.Println(" sym llm setup") + return + } + + fmt.Printf("Testing engine: %s\n\n", active.Name()) + + // Create test request + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + response, err := client.Request( + "You are a helpful assistant. Respond with exactly one word.", + "Say 'OK' to confirm you are working.", + ).Execute(ctx) + + if err != nil { + fmt.Printf("\nโŒ Test failed: %v\n", err) + os.Exit(1) + } + + fmt.Printf("\nโœ“ Test successful!\n") + fmt.Printf(" Response: %s\n", strings.TrimSpace(response)) +} + +// promptLLMBackendSetup is called from init command to setup LLM engine. 
+func promptLLMBackendSetup() { + fmt.Println("\n๐Ÿค– LLM Engine Configuration") + fmt.Println(" Symphony uses LLM for policy conversion and code validation.") + fmt.Println() + + // Detect available CLIs + clis := engine.DetectAvailableCLIs() + + // Check API key + cfg := llm.LoadLLMConfig() + hasAPIKey := cfg.HasAPIKey() + + // Show detected tools + fmt.Println(" Detected LLM tools:") + hasAnyCLI := false + for _, cli := range clis { + status := "โœ—" + if cli.Available { + status = "โœ“" + hasAnyCLI = true + } + version := "" + if cli.Version != "" { + version = fmt.Sprintf(" (%s)", cli.Version) + } + fmt.Printf(" %s %s%s\n", status, cli.Name, version) + } + + if hasAPIKey { + fmt.Println(" โœ“ OpenAI API key (configured)") + } else { + fmt.Println(" โœ— OpenAI API key (not set)") + } + fmt.Println() + + // If nothing available, skip + if !hasAnyCLI && !hasAPIKey { + fmt.Println(" โš ๏ธ No LLM engine available") + fmt.Println(" You can configure one later with: sym llm setup") + return + } + + // Build selection items + var items []string + var modes []engine.Mode + + items = append(items, "Auto (recommended) - Use best available engine") + modes = append(modes, engine.ModeAuto) + + for _, cli := range clis { + if cli.Available { + items = append(items, fmt.Sprintf("%s CLI", cli.Name)) + modes = append(modes, engine.ModeCLI) + } + } + + if hasAPIKey { + items = append(items, "OpenAI API") + modes = append(modes, engine.ModeAPI) + } + + items = append(items, "Skip (configure later)") + modes = append(modes, "") + + templates := &promptui.SelectTemplates{ + Label: "{{ . }}?", + Active: "โ–ธ {{ . | cyan }}", + Inactive: " {{ . }}", + Selected: "โœ“ {{ . 
| green }}", + } + + selectPrompt := promptui.Select{ + Label: "Select your preferred LLM engine", + Items: items, + Templates: templates, + Size: len(items), + } + + index, _, err := selectPrompt.Run() + if err != nil || modes[index] == "" { + fmt.Println("\n LLM engine configuration skipped") + fmt.Println(" Run 'sym llm setup' to configure later") + return + } + + // Update config + cfg.Backend = modes[index] + + // If CLI selected, set the specific CLI provider + if modes[index] == engine.ModeCLI { + // Find which CLI was selected + cliIndex := index - 1 // Account for "Auto" option + cliCount := 0 + for _, cli := range clis { + if cli.Available { + if cliCount == cliIndex { + cfg.CLI = string(cli.Provider) + provider, _ := engine.GetProvider(cli.Provider) + if provider != nil { + cfg.Model = provider.DefaultModel + cfg.LargeModel = provider.LargeModel + } + break + } + cliCount++ + } + } + } + + // Save config + if err := llm.SaveLLMConfig(cfg); err != nil { + fmt.Printf("\n โš ๏ธ Failed to save LLM configuration: %v\n", err) + return + } + + fmt.Printf("\n โœ“ LLM engine set to: %s\n", cfg.Backend) + if cfg.CLI != "" { + fmt.Printf(" CLI: %s\n", cfg.CLI) + } + fmt.Println(" Configuration saved to .sym/.env") +} diff --git a/internal/cmd/validate.go b/internal/cmd/validate.go index 7184b54..b805eea 100644 --- a/internal/cmd/validate.go +++ b/internal/cmd/validate.go @@ -73,18 +73,19 @@ func runValidate(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to parse policy: %w", err) } - // Get OpenAI API key - apiKey, err := getAPIKey() - if err != nil { - return fmt.Errorf("OpenAI API key not configured: %w\nTip: Run 'sym init' or set OPENAI_API_KEY in .sym/.env", err) - } - // Create LLM client llmClient := llm.NewClient( - apiKey, - llm.WithTimeout(time.Duration(validateTimeout)*time.Second), + llm.WithTimeout(time.Duration(validateTimeout) * time.Second), ) + // Ensure at least one backend is available (MCP/CLI/API) + availabilityCtx, cancel 
:= context.WithTimeout(context.Background(), time.Duration(validateTimeout)*time.Second) + defer cancel() + + if err := llmClient.CheckAvailability(availabilityCtx); err != nil { + return fmt.Errorf("no available LLM backend for validate: %w\nTip: run 'sym init --setup-llm' or configure LLM_BACKEND / LLM_CLI / OPENAI_API_KEY in .sym/.env", err) + } + var changes []validator.GitChange if validateStaged { changes, err = validator.GetStagedChanges() diff --git a/internal/converter/converter.go b/internal/converter/converter.go index 94053f5..e32836e 100644 --- a/internal/converter/converter.go +++ b/internal/converter/converter.go @@ -405,8 +405,8 @@ Reason: Requires knowing which packages are "large"`, linterDescriptions, routin userPrompt := fmt.Sprintf("Rule: %s\nCategory: %s", rule.Say, rule.Category) - // Call LLM with power model + low reasoning (needs some thought for linter selection) - response, err := c.llmClient.Request(systemPrompt, userPrompt).WithPower(llm.ReasoningMedium).Execute(ctx) + // Call LLM with low complexity (needs some thought for linter selection) + response, err := c.llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityLow).Execute(ctx) if err != nil { fmt.Fprintf(os.Stderr, "Warning: LLM routing failed for rule %s: %v\n", rule.ID, err) return []string{} // Will fall back to llm-validator diff --git a/internal/llm/client.go b/internal/llm/client.go index a0ea105..800f190 100644 --- a/internal/llm/client.go +++ b/internal/llm/client.go @@ -1,119 +1,205 @@ package llm import ( - "bytes" "context" - "encoding/json" "fmt" - "io" - "net/http" + "os" "time" "github.com/DevSymphony/sym-cli/internal/envutil" + "github.com/DevSymphony/sym-cli/internal/llm/engine" mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" ) const ( - openAIAPIURL = "https://api.openai.com/v1/chat/completions" - defaultFastModel = "gpt-4o-mini" - defaultPowerModel = "gpt-5-mini" defaultMaxTokens = 1000 defaultTemperature = 1.0 defaultTimeout = 60 * 
time.Second ) -// Mode defines the LLM client mode -type Mode string - const ( - ModeAPI Mode = "api" - ModeMCP Mode = "mcp" + // ModeAPI uses OpenAI API. + ModeAPI = engine.ModeAPI + // ModeMCP uses MCP sampling. + ModeMCP = engine.ModeMCP + // ModeCLI uses CLI engine. + ModeCLI = engine.ModeCLI + // ModeAuto automatically selects the best available engine. + ModeAuto = engine.ModeAuto ) -// ReasoningEffort defines the reasoning effort level for o3-mini -type ReasoningEffort string - -const ( - ReasoningMinimal ReasoningEffort = "minimal" - ReasoningLow ReasoningEffort = "low" - ReasoningMedium ReasoningEffort = "medium" - ReasoningHigh ReasoningEffort = "high" -) - -// Client represents an LLM client +// Client represents an LLM client with fallback chain support. type Client struct { - mode Mode - apiKey string - fastModel string - powerModel string - httpClient *http.Client - mcpSession *mcpsdk.ServerSession + // Engine configuration + config *LLMConfig + mode engine.Mode + engines []engine.LLMEngine + mcpSession *mcpsdk.ServerSession + + // Default request parameters maxTokens int temperature float64 verbose bool } -// ClientOption is a functional option for configuring the client +// ClientOption is a functional option for configuring the client. type ClientOption func(*Client) -// WithMaxTokens sets the default max tokens +// WithMaxTokens sets the default max tokens. func WithMaxTokens(maxTokens int) ClientOption { return func(c *Client) { c.maxTokens = maxTokens } } -// WithTemperature sets the default temperature +// WithTemperature sets the default temperature. func WithTemperature(temperature float64) ClientOption { return func(c *Client) { c.temperature = temperature } } -// WithTimeout sets the HTTP client timeout -func WithTimeout(timeout time.Duration) ClientOption { - return func(c *Client) { c.httpClient.Timeout = timeout } +// WithTimeout sets the HTTP client timeout (for API engine). 
+func WithTimeout(_ time.Duration) ClientOption { + // Note: This is handled by individual engines now + return func(_ *Client) {} } -// WithVerbose enables verbose logging +// WithVerbose enables verbose logging. func WithVerbose(verbose bool) ClientOption { return func(c *Client) { c.verbose = verbose } } -// WithMCPSession sets the MCP session for MCP mode +// WithMCPSession sets the MCP session for MCP mode. func WithMCPSession(session *mcpsdk.ServerSession) ClientOption { return func(c *Client) { c.mcpSession = session - c.mode = ModeMCP + c.mode = engine.ModeMCP + } +} + +// WithConfig sets a custom LLM configuration. +func WithConfig(cfg *LLMConfig) ClientOption { + return func(c *Client) { + if cfg == nil { + return + } + c.config = cfg + if mode := cfg.GetEffectiveBackend(); mode != "" { + c.mode = mode + } } } -// NewClient creates a new LLM client -func NewClient(apiKey string, opts ...ClientOption) *Client { - if apiKey == "" { - apiKey = envutil.GetAPIKey("OPENAI_API_KEY") +// WithMode sets the preferred engine mode. +func WithMode(mode engine.Mode) ClientOption { + return func(c *Client) { + c.mode = mode } +} + +// NewClient creates a new LLM client. +func NewClient(opts ...ClientOption) *Client { + // Load default config + config := LoadLLMConfig() + + apiKey := envutil.GetAPIKey("OPENAI_API_KEY") + config.APIKey = apiKey client := &Client{ - mode: ModeAPI, - apiKey: apiKey, - fastModel: defaultFastModel, - powerModel: defaultPowerModel, - httpClient: &http.Client{Timeout: defaultTimeout}, + config: config, + mode: config.GetEffectiveBackend(), maxTokens: defaultMaxTokens, temperature: defaultTemperature, verbose: false, } + // Apply options for _, opt := range opts { opt(client) } + // Initialize engine chain + client.initEngines() + return client } -// Request creates a new request builder +// initEngines initializes the engine fallback chain based on configuration. 
+func (c *Client) initEngines() { + c.engines = []engine.LLMEngine{} + + // Determine which engines to include based on mode + switch c.mode { + case engine.ModeMCP: + c.addMCPEngine() + case engine.ModeCLI: + c.addCLIEngine() + case engine.ModeAPI: + c.addAPIEngine() + case engine.ModeAuto: + fallthrough + default: + // add all available engines + c.addMCPEngine() + c.addCLIEngine() + c.addAPIEngine() + } +} + +// addMCPEngine adds MCP engine if session is available. +func (c *Client) addMCPEngine() { + if c.mcpSession != nil { + eng := engine.NewMCPEngine(c.mcpSession, engine.WithMCPVerbose(c.verbose)) + c.engines = append(c.engines, eng) + } +} + +// addCLIEngine adds CLI engine if configured. +func (c *Client) addCLIEngine() { + if c.config.CLI != "" { + providerType := engine.CLIProviderType(c.config.CLI) + if !providerType.IsValid() { + return + } + + opts := []engine.CLIEngineOption{} + + if c.config.CLIPath != "" { + opts = append(opts, engine.WithCLIPath(c.config.CLIPath)) + } + + if c.config.Model != "" { + opts = append(opts, engine.WithCLIModel(c.config.Model)) + } + + if c.config.LargeModel != "" { + opts = append(opts, engine.WithCLILargeModel(c.config.LargeModel)) + } + + if c.verbose { + opts = append(opts, engine.WithCLIVerbose(true)) + } + + eng, err := engine.NewCLIEngine(providerType, opts...) + if err == nil && eng.IsAvailable() { + c.engines = append(c.engines, eng) + } + } +} + +// addAPIEngine adds API engine if key is available. +func (c *Client) addAPIEngine() { + apiKey := c.config.GetAPIKey() + if apiKey != "" { + eng := engine.NewAPIEngine(apiKey, engine.WithAPIVerbose(c.verbose)) + c.engines = append(c.engines, eng) + } +} + +// Request creates a new request builder. 
// // Usage: // -// client.Request(system, user).Execute(ctx) // fast model (gpt-4o-mini) -// client.Request(system, user).WithPower(llm.ReasoningMedium).Execute(ctx) // power model (o3-mini) +// client.Request(system, user).Execute(ctx) // default complexity +// client.Request(system, user).WithComplexity(llm.ComplexityMedium).Execute(ctx) // higher complexity +// client.Request(system, user).WithComplexity(engine.ComplexityHigh).Execute(ctx) // explicit complexity // client.Request(system, user).WithMaxTokens(2000).Execute(ctx) // custom tokens func (c *Client) Request(systemPrompt, userPrompt string) *RequestBuilder { return &RequestBuilder{ @@ -122,231 +208,120 @@ func (c *Client) Request(systemPrompt, userPrompt string) *RequestBuilder { user: userPrompt, maxTokens: c.maxTokens, temperature: c.temperature, - usePower: false, + complexity: engine.ComplexityLow, } } -// RequestBuilder builds and executes LLM requests with chain methods +// GetActiveEngine returns the first available engine. +func (c *Client) GetActiveEngine() engine.LLMEngine { + for _, e := range c.engines { + if e.IsAvailable() { + return e + } + } + return nil +} + +// GetEngines returns all configured engines. +func (c *Client) GetEngines() []engine.LLMEngine { + return c.engines +} + +// GetConfig returns the LLM configuration. +func (c *Client) GetConfig() *LLMConfig { + return c.config +} + +// CheckAvailability checks if any LLM engine is available. +func (c *Client) CheckAvailability(ctx context.Context) error { + eng := c.GetActiveEngine() + if eng == nil { + return fmt.Errorf("no available LLM engine") + } + + // For API engine, do a simple test request + if eng.Name() == "openai-api" { + _, err := c.Request("You are a test assistant.", "Say 'OK'").Execute(ctx) + if err != nil { + return fmt.Errorf("OpenAI API not available: %w", err) + } + } + + return nil +} + +// RequestBuilder builds and executes LLM requests with chain methods. 
type RequestBuilder struct { client *Client system string user string maxTokens int temperature float64 - usePower bool - effort ReasoningEffort + complexity engine.Complexity } -// WithPower enables power model (o3-mini) with specified reasoning effort -func (r *RequestBuilder) WithPower(effort ReasoningEffort) *RequestBuilder { - r.usePower = true - r.effort = effort +// WithComplexity sets the task complexity hint (engine-agnostic). +func (r *RequestBuilder) WithComplexity(c engine.Complexity) *RequestBuilder { + r.complexity = c return r } -// WithMaxTokens sets max tokens for this request +// WithMaxTokens sets max tokens for this request. func (r *RequestBuilder) WithMaxTokens(tokens int) *RequestBuilder { r.maxTokens = tokens return r } -// WithTemperature sets temperature for this request +// WithTemperature sets temperature for this request. func (r *RequestBuilder) WithTemperature(temp float64) *RequestBuilder { r.temperature = temp return r } -// Execute sends the request and returns the response +// Execute sends the request and returns the response. 
func (r *RequestBuilder) Execute(ctx context.Context) (string, error) { - if r.client.mode == ModeMCP { - return r.client.executeViaMCP(ctx, r) + req := &engine.Request{ + SystemPrompt: r.system, + UserPrompt: r.user, + MaxTokens: r.maxTokens, + Temperature: r.temperature, + Complexity: r.complexity, } - return r.client.executeViaAPI(ctx, r) -} - -// openAIRequest represents the OpenAI API request structure -type openAIRequest struct { - Model string `json:"model"` - Messages []openAIMessage `json:"messages"` - MaxTokens int `json:"max_completion_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - ReasoningEffort string `json:"reasoning_effort,omitempty"` -} -type openAIMessage struct { - Role string `json:"role"` - Content string `json:"content"` + return r.client.executeWithFallback(ctx, req) } -type openAIResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` - Error *struct { - Message string `json:"message"` - Type string `json:"type"` - Code string `json:"code"` - } `json:"error,omitempty"` -} - -// executeViaAPI sends request via OpenAI API -func (c *Client) executeViaAPI(ctx context.Context, r *RequestBuilder) (string, error) { - if c.apiKey == "" { - return "", fmt.Errorf("OpenAI API key not configured") - } - - model := c.fastModel - if r.usePower { - model = c.powerModel - } - - reqBody := openAIRequest{ - Model: model, - Messages: []openAIMessage{ - {Role: "user", Content: r.system + "\n\n" + r.user}, - }, - MaxTokens: r.maxTokens, - Temperature: r.temperature, - } +// 
executeWithFallback tries engines in priority order. +func (c *Client) executeWithFallback(ctx context.Context, req *engine.Request) (string, error) { + var lastErr error - if r.usePower { - reqBody.ReasoningEffort = string(r.effort) - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return "", fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", openAIAPIURL, bytes.NewBuffer(jsonData)) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.apiKey) - - if c.verbose { - if r.usePower { - fmt.Printf("OpenAI API request:\n Model: %s\n Reasoning: %s\n Prompt length: %d chars\n", - model, r.effort, len(r.user)) - } else { - fmt.Printf("OpenAI API request:\n Model: %s\n Prompt length: %d chars\n", - model, len(r.user)) + for _, eng := range c.engines { + if !eng.IsAvailable() { + continue } - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return "", fmt.Errorf("failed to send request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(body)) - } - - var apiResp openAIResponse - if err := json.Unmarshal(body, &apiResp); err != nil { - return "", fmt.Errorf("failed to unmarshal response: %w", err) - } - - if apiResp.Error != nil { - return "", fmt.Errorf("OpenAI API error: %s (type: %s, code: %s)", - apiResp.Error.Message, apiResp.Error.Type, apiResp.Error.Code) - } - - if len(apiResp.Choices) == 0 { - return "", fmt.Errorf("no choices in response") - } - - content := apiResp.Choices[0].Message.Content - - if c.verbose { - fmt.Printf("OpenAI API response:\n Tokens: %d\n Content length: %d chars\n", - 
apiResp.Usage.TotalTokens, len(content)) - } - - return content, nil -} - -// executeViaMCP sends request via MCP sampling -func (c *Client) executeViaMCP(ctx context.Context, r *RequestBuilder) (string, error) { - if c.mcpSession == nil { - return "", fmt.Errorf("MCP session not available") - } - if c.verbose { - fmt.Printf("MCP Sampling request:\n MaxTokens: %d\n Prompt length: %d chars\n", - r.maxTokens, len(r.user)) - } - - combinedPrompt := r.system + "\n\n" + r.user - - result, err := c.mcpSession.CreateMessage(ctx, &mcpsdk.CreateMessageParams{ - Messages: []*mcpsdk.SamplingMessage{ - { - Role: "user", - Content: &mcpsdk.TextContent{Text: combinedPrompt}, - }, - }, - MaxTokens: int64(r.maxTokens), - }) - if err != nil { - return "", fmt.Errorf("MCP sampling failed: %w", err) - } + result, err := eng.Execute(ctx, req) + if err == nil { + return result, nil + } - var response string - if textContent, ok := result.Content.(*mcpsdk.TextContent); ok { - response = textContent.Text - } else { - return "", fmt.Errorf("unexpected content type from MCP sampling") + lastErr = err + if c.verbose { + fmt.Fprintf(os.Stderr, "โš ๏ธ %s failed: %v, trying next engine...\n", eng.Name(), err) + } } - if c.verbose { - fmt.Printf("MCP Sampling response:\n Model: %s\n Content length: %d chars\n", - result.Model, len(response)) + if lastErr != nil { + return "", fmt.Errorf("all engines failed, last error: %w", lastErr) } - return response, nil + return "", fmt.Errorf("no available LLM engine configured") } -// CheckAvailability checks if the LLM is available -func (c *Client) CheckAvailability(ctx context.Context) error { - if c.mode == ModeMCP { - if c.mcpSession == nil { - return fmt.Errorf("MCP session not available") - } - return nil +// ExecuteDirect executes request on a specific engine without fallback. 
+func (c *Client) ExecuteDirect(ctx context.Context, eng engine.LLMEngine, req *engine.Request) (string, error) { + if !eng.IsAvailable() { + return "", fmt.Errorf("engine %s is not available", eng.Name()) } - - if c.apiKey == "" { - return fmt.Errorf("OPENAI_API_KEY environment variable not set") - } - - _, err := c.Request("You are a test assistant.", "Say 'OK'").Execute(ctx) - if err != nil { - return fmt.Errorf("OpenAI API not available: %w", err) - } - - return nil + return eng.Execute(ctx, req) } diff --git a/internal/llm/client_test.go b/internal/llm/client_test.go new file mode 100644 index 0000000..c462d48 --- /dev/null +++ b/internal/llm/client_test.go @@ -0,0 +1,97 @@ +package llm + +import ( + "testing" + + "github.com/DevSymphony/sym-cli/internal/llm/engine" + "github.com/stretchr/testify/assert" +) + +func TestNewClient(t *testing.T) { + t.Run("default_config", func(t *testing.T) { + client := NewClient() + assert.NotNil(t, client) + assert.NotNil(t, client.GetConfig()) + }) + + t.Run("with_options_and_config", func(t *testing.T) { + cfg := &LLMConfig{ + Backend: engine.ModeAPI, + APIKey: "sk-test", + } + client := NewClient(WithConfig(cfg), WithVerbose(true)) + assert.NotNil(t, client) + assert.Equal(t, engine.ModeAPI, client.config.Backend) + }) + + t.Run("with_mode_option", func(t *testing.T) { + client := NewClient(WithMode(engine.ModeAPI)) + assert.NotNil(t, client) + }) +} + +func TestClient_GetActiveEngine(t *testing.T) { + t.Run("with API engine", func(t *testing.T) { + cfg := &LLMConfig{ + Backend: engine.ModeAPI, + APIKey: "sk-test", + } + client := NewClient(WithConfig(cfg)) + eng := client.GetActiveEngine() + assert.NotNil(t, eng) + assert.Equal(t, "openai-api", eng.Name()) + }) + + t.Run("no engine available", func(t *testing.T) { + cfg := &LLMConfig{ + Backend: engine.ModeAPI, + // No API key + } + client := NewClient(WithConfig(cfg)) + eng := client.GetActiveEngine() + assert.Nil(t, eng) + }) +} + +func TestRequestBuilder(t *testing.T) 
{ + client := NewClient() + + t.Run("basic request", func(t *testing.T) { + builder := client.Request("system", "user") + assert.NotNil(t, builder) + }) + + t.Run("with complexity", func(t *testing.T) { + builder := client.Request("system", "user"). + WithComplexity(engine.ComplexityHigh) + assert.NotNil(t, builder) + }) + + t.Run("with max tokens", func(t *testing.T) { + builder := client.Request("system", "user"). + WithMaxTokens(2000) + assert.NotNil(t, builder) + }) + + t.Run("with temperature", func(t *testing.T) { + builder := client.Request("system", "user"). + WithTemperature(0.7) + assert.NotNil(t, builder) + }) + + t.Run("chained options", func(t *testing.T) { + builder := client.Request("system", "user"). + WithComplexity(engine.ComplexityMedium). + WithMaxTokens(1500). + WithTemperature(0.8) + assert.NotNil(t, builder) + }) +} + +func TestModeConstants(t *testing.T) { + // Verify backward compatibility + assert.Equal(t, engine.ModeAPI, ModeAPI) + assert.Equal(t, engine.ModeMCP, ModeMCP) + assert.Equal(t, engine.ModeCLI, ModeCLI) + assert.Equal(t, engine.ModeAuto, ModeAuto) +} diff --git a/internal/llm/complexity.go b/internal/llm/complexity.go new file mode 100644 index 0000000..243d239 --- /dev/null +++ b/internal/llm/complexity.go @@ -0,0 +1,17 @@ +package llm + +import "github.com/DevSymphony/sym-cli/internal/llm/engine" + +// Complexity re-exports engine.Complexity for backward compatibility. +type Complexity = engine.Complexity + +const ( + // ComplexityMinimal is for trivial lookups. + ComplexityMinimal Complexity = engine.ComplexityMinimal + // ComplexityLow is for simple transformations. + ComplexityLow Complexity = engine.ComplexityLow + // ComplexityMedium is for moderate reasoning. + ComplexityMedium Complexity = engine.ComplexityMedium + // ComplexityHigh is for complex reasoning. 
+ ComplexityHigh Complexity = engine.ComplexityHigh +) diff --git a/internal/llm/config.go b/internal/llm/config.go new file mode 100644 index 0000000..fe6eca4 --- /dev/null +++ b/internal/llm/config.go @@ -0,0 +1,366 @@ +package llm + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/DevSymphony/sym-cli/internal/llm/engine" +) + +const ( + // Default .sym/.env file location relative to repo root + defaultEnvFile = ".sym/.env" + + // Environment variable keys + envKeyLLMBackend = "LLM_BACKEND" + envKeyLLMCLI = "LLM_CLI" + envKeyLLMCLIPath = "LLM_CLI_PATH" + envKeyLLMModel = "LLM_MODEL" + envKeyLLMLarge = "LLM_LARGE_MODEL" + envKeyAPIKey = "OPENAI_API_KEY" +) + +// LLMConfig holds LLM engine configuration. +type LLMConfig struct { + // Backend is the preferred engine mode (auto, mcp, cli, api). + Backend engine.Mode `json:"backend"` + + // CLI is the CLI provider type (claude, gemini). + CLI string `json:"cli"` + + // CLIPath is a custom path to the CLI executable (optional). + CLIPath string `json:"cli_path"` + + // Model is the default model name for CLI engine. + Model string `json:"model"` + + // LargeModel is the model for high complexity tasks (optional). + LargeModel string `json:"large_model"` + + // APIKey is loaded from environment (not saved to config). + APIKey string `json:"-"` +} + +// DefaultLLMConfig returns the default configuration. +func DefaultLLMConfig() *LLMConfig { + return &LLMConfig{ + Backend: engine.ModeAuto, + CLI: "", + CLIPath: "", + Model: "", + } +} + +// LoadLLMConfig loads LLM configuration from .sym/.env file and environment. +func LoadLLMConfig() *LLMConfig { + cfg := DefaultLLMConfig() + + // Load from .sym/.env file first + envPath := defaultEnvFile + loadConfigFromEnvFile(envPath, cfg) + + // Override with system environment variables + loadConfigFromEnv(cfg) + + return cfg +} + +// LoadLLMConfigFromDir loads LLM configuration from a specific directory. 
+func LoadLLMConfigFromDir(dir string) *LLMConfig { + cfg := DefaultLLMConfig() + + // Load from .env file in the specified directory + envPath := filepath.Join(dir, ".env") + loadConfigFromEnvFile(envPath, cfg) + + // Override with system environment variables + loadConfigFromEnv(cfg) + + return cfg +} + +// loadConfigFromEnvFile reads config values from .env file. +func loadConfigFromEnvFile(envPath string, cfg *LLMConfig) { + file, err := os.Open(envPath) + if err != nil { + return // File doesn't exist, use defaults + } + defer func() { _ = file.Close() }() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + // Skip comments and empty lines + if len(line) == 0 || line[0] == '#' { + continue + } + + // Parse key=value + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + switch key { + case envKeyLLMBackend: + if engine.Mode(value).IsValid() { + cfg.Backend = engine.Mode(value) + } + case envKeyLLMCLI: + cfg.CLI = value + case envKeyLLMCLIPath: + cfg.CLIPath = value + case envKeyLLMModel: + cfg.Model = value + case envKeyLLMLarge: + cfg.LargeModel = value + case envKeyAPIKey: + cfg.APIKey = value + } + } +} + +// loadConfigFromEnv loads config from system environment variables. 
+func loadConfigFromEnv(cfg *LLMConfig) { + if backend := os.Getenv(envKeyLLMBackend); backend != "" { + if engine.Mode(backend).IsValid() { + cfg.Backend = engine.Mode(backend) + } + } + + if cli := os.Getenv(envKeyLLMCLI); cli != "" { + cfg.CLI = cli + } + + if cliPath := os.Getenv(envKeyLLMCLIPath); cliPath != "" { + cfg.CLIPath = cliPath + } + + if model := os.Getenv(envKeyLLMModel); model != "" { + cfg.Model = model + } + + if large := os.Getenv(envKeyLLMLarge); large != "" { + cfg.LargeModel = large + } + + if apiKey := os.Getenv(envKeyAPIKey); apiKey != "" { + cfg.APIKey = apiKey + } +} + +// SaveLLMConfig saves LLM configuration to .sym/.env file. +func SaveLLMConfig(cfg *LLMConfig) error { + return SaveLLMConfigToDir(".sym", cfg) +} + +// SaveLLMConfigToDir saves LLM configuration to a specific directory. +func SaveLLMConfigToDir(dir string, cfg *LLMConfig) error { + // Ensure directory exists + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + envPath := filepath.Join(dir, ".env") + + // Read existing content + existingLines, existingKeys := readExistingEnvFile(envPath) + + // Prepare new values + newValues := map[string]string{} + + if cfg.Backend != "" && cfg.Backend != engine.ModeAuto { + newValues[envKeyLLMBackend] = string(cfg.Backend) + } + + if cfg.CLI != "" { + newValues[envKeyLLMCLI] = cfg.CLI + } + + if cfg.CLIPath != "" { + newValues[envKeyLLMCLIPath] = cfg.CLIPath + } + + if cfg.Model != "" { + newValues[envKeyLLMModel] = cfg.Model + } + + if cfg.LargeModel != "" { + newValues[envKeyLLMLarge] = cfg.LargeModel + } + + // Build output lines + var outputLines []string + + // Update existing lines + for _, line := range existingLines { + trimmed := strings.TrimSpace(line) + + // Keep comments and empty lines + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + outputLines = append(outputLines, line) + continue + } + + // Parse key + parts := strings.SplitN(trimmed, "=", 2) + if 
len(parts) != 2 { + outputLines = append(outputLines, line) + continue + } + + key := strings.TrimSpace(parts[0]) + + // Check if we have a new value for this key + if newValue, ok := newValues[key]; ok { + outputLines = append(outputLines, fmt.Sprintf("%s=%s", key, newValue)) + delete(newValues, key) // Mark as processed + } else { + outputLines = append(outputLines, line) + } + } + + // Add LLM config section header if needed + hasLLMSection := false + for key := range existingKeys { + if strings.HasPrefix(key, "LLM_") { + hasLLMSection = true + break + } + } + + // Add new keys that weren't in the file + if len(newValues) > 0 { + if !hasLLMSection { + outputLines = append(outputLines, "", "# LLM Backend Configuration") + } + + for key, value := range newValues { + outputLines = append(outputLines, fmt.Sprintf("%s=%s", key, value)) + } + } + + // Write to file + content := strings.Join(outputLines, "\n") + if !strings.HasSuffix(content, "\n") { + content += "\n" + } + + return os.WriteFile(envPath, []byte(content), 0600) +} + +// readExistingEnvFile reads existing .env file content. +func readExistingEnvFile(envPath string) ([]string, map[string]bool) { + var lines []string + keys := make(map[string]bool) + + file, err := os.Open(envPath) + if err != nil { + return lines, keys + } + defer func() { _ = file.Close() }() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + lines = append(lines, line) + + // Track existing keys + trimmed := strings.TrimSpace(line) + if len(trimmed) > 0 && !strings.HasPrefix(trimmed, "#") { + parts := strings.SplitN(trimmed, "=", 2) + if len(parts) == 2 { + keys[strings.TrimSpace(parts[0])] = true + } + } + } + + return lines, keys +} + +// GetAPIKey returns the API key from config or environment. +func (c *LLMConfig) GetAPIKey() string { + if c.APIKey != "" { + return c.APIKey + } + return os.Getenv(envKeyAPIKey) +} + +// HasCLI returns true if CLI is configured. 
+func (c *LLMConfig) HasCLI() bool { + return c.CLI != "" +} + +// HasAPIKey returns true if API key is available. +func (c *LLMConfig) HasAPIKey() bool { + return c.GetAPIKey() != "" +} + +// GetEffectiveBackend returns the actual engine to use based on availability. +func (c *LLMConfig) GetEffectiveBackend() engine.Mode { + if c.Backend != engine.ModeAuto { + return c.Backend + } + + // Auto mode: prefer CLI if available, then API + if c.HasCLI() { + return engine.ModeCLI + } + + if c.HasAPIKey() { + return engine.ModeAPI + } + + return engine.ModeAuto +} + +// Validate checks if the configuration is valid. +func (c *LLMConfig) Validate() error { + if c.Backend != "" && !c.Backend.IsValid() { + return fmt.Errorf("invalid engine mode: %s", c.Backend) + } + + if c.CLI != "" && !engine.CLIProviderType(c.CLI).IsValid() { + return fmt.Errorf("unsupported CLI provider: %s", c.CLI) + } + + return nil +} + +// String returns a human-readable representation of the config. +func (c *LLMConfig) String() string { + var parts []string + + parts = append(parts, fmt.Sprintf("Backend: %s", c.Backend)) + + if c.CLI != "" { + parts = append(parts, fmt.Sprintf("CLI: %s", c.CLI)) + } + + if c.CLIPath != "" { + parts = append(parts, fmt.Sprintf("CLI Path: %s", c.CLIPath)) + } + + if c.Model != "" { + parts = append(parts, fmt.Sprintf("Model: %s", c.Model)) + } + + if c.LargeModel != "" { + parts = append(parts, fmt.Sprintf("Large Model: %s", c.LargeModel)) + } + + if c.HasAPIKey() { + parts = append(parts, "API Key: configured") + } else { + parts = append(parts, "API Key: not set") + } + + return strings.Join(parts, ", ") +} diff --git a/internal/llm/config_test.go b/internal/llm/config_test.go new file mode 100644 index 0000000..b5bc4e1 --- /dev/null +++ b/internal/llm/config_test.go @@ -0,0 +1,160 @@ +package llm + +import ( + "os" + "path/filepath" + "testing" + + "github.com/DevSymphony/sym-cli/internal/llm/engine" + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" +) + +func TestDefaultLLMConfig(t *testing.T) { + cfg := DefaultLLMConfig() + + assert.Equal(t, engine.ModeAuto, cfg.Backend) + assert.Empty(t, cfg.CLI) + assert.Empty(t, cfg.CLIPath) + assert.Empty(t, cfg.Model) +} + +func TestLLMConfig_HasCLI(t *testing.T) { + t.Run("with CLI", func(t *testing.T) { + cfg := &LLMConfig{CLI: "claude"} + assert.True(t, cfg.HasCLI()) + }) + + t.Run("without CLI", func(t *testing.T) { + cfg := &LLMConfig{} + assert.False(t, cfg.HasCLI()) + }) +} + +func TestLLMConfig_HasAPIKey(t *testing.T) { + t.Run("with API key in config", func(t *testing.T) { + cfg := &LLMConfig{APIKey: "sk-test"} + assert.True(t, cfg.HasAPIKey()) + }) + + t.Run("without API key", func(t *testing.T) { + cfg := &LLMConfig{} + assert.False(t, cfg.HasAPIKey()) + }) +} + +func TestLLMConfig_GetEffectiveBackend(t *testing.T) { + t.Run("explicit mode", func(t *testing.T) { + cfg := &LLMConfig{Backend: engine.ModeCLI} + assert.Equal(t, engine.ModeCLI, cfg.GetEffectiveBackend()) + }) + + t.Run("auto with CLI", func(t *testing.T) { + cfg := &LLMConfig{Backend: engine.ModeAuto, CLI: "claude"} + assert.Equal(t, engine.ModeCLI, cfg.GetEffectiveBackend()) + }) + + t.Run("auto with API key", func(t *testing.T) { + cfg := &LLMConfig{Backend: engine.ModeAuto, APIKey: "sk-test"} + assert.Equal(t, engine.ModeAPI, cfg.GetEffectiveBackend()) + }) + + t.Run("auto with nothing", func(t *testing.T) { + cfg := &LLMConfig{Backend: engine.ModeAuto} + assert.Equal(t, engine.ModeAuto, cfg.GetEffectiveBackend()) + }) +} + +func TestLLMConfig_Validate(t *testing.T) { + t.Run("valid config", func(t *testing.T) { + cfg := &LLMConfig{ + Backend: engine.ModeAuto, + CLI: "claude", + } + assert.NoError(t, cfg.Validate()) + }) + + t.Run("invalid backend", func(t *testing.T) { + cfg := &LLMConfig{Backend: engine.Mode("invalid")} + assert.Error(t, cfg.Validate()) + }) + + t.Run("invalid CLI provider", func(t *testing.T) { + cfg := &LLMConfig{CLI: "invalid-cli"} + 
assert.Error(t, cfg.Validate()) + }) + + t.Run("empty config is valid", func(t *testing.T) { + cfg := &LLMConfig{} + assert.NoError(t, cfg.Validate()) + }) +} + +func TestLLMConfig_String(t *testing.T) { + cfg := &LLMConfig{ + Backend: engine.ModeAuto, + CLI: "claude", + Model: "claude-3-opus", + } + + str := cfg.String() + assert.Contains(t, str, "Backend: auto") + assert.Contains(t, str, "CLI: claude") + assert.Contains(t, str, "Model: claude-3-opus") +} + +func TestSaveLLMConfig(t *testing.T) { + tmpDir := t.TempDir() + + cfg := &LLMConfig{ + Backend: engine.ModeCLI, + CLI: "claude", + Model: "claude-3-opus", + LargeModel: "claude-3-opus", + } + + err := SaveLLMConfigToDir(tmpDir, cfg) + require.NoError(t, err) + + // Verify file was created + envPath := filepath.Join(tmpDir, ".env") + _, err = os.Stat(envPath) + require.NoError(t, err) + + // Read and verify content + content, err := os.ReadFile(envPath) + require.NoError(t, err) + + assert.Contains(t, string(content), "LLM_BACKEND=cli") + assert.Contains(t, string(content), "LLM_CLI=claude") + assert.Contains(t, string(content), "LLM_MODEL=claude-3-opus") +} + +func TestLoadLLMConfigFromDir(t *testing.T) { + tmpDir := t.TempDir() + + // Create .env file + envContent := `# Test config +LLM_BACKEND=cli +LLM_CLI=gemini +LLM_MODEL=gemini-pro +` + envPath := filepath.Join(tmpDir, ".env") + err := os.WriteFile(envPath, []byte(envContent), 0600) + require.NoError(t, err) + + cfg := LoadLLMConfigFromDir(tmpDir) + + assert.Equal(t, engine.ModeCLI, cfg.Backend) + assert.Equal(t, "gemini", cfg.CLI) + assert.Equal(t, "gemini-pro", cfg.Model) +} + +func TestLoadLLMConfigFromDir_NonExistent(t *testing.T) { + cfg := LoadLLMConfigFromDir("/nonexistent/path") + + // Should return defaults + assert.Equal(t, engine.ModeAuto, cfg.Backend) + assert.Empty(t, cfg.CLI) +} + diff --git a/internal/llm/engine/api.go b/internal/llm/engine/api.go new file mode 100644 index 0000000..6c4227f --- /dev/null +++ b/internal/llm/engine/api.go @@ 
-0,0 +1,247 @@ +package engine + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "time" +) + +const ( + openAIAPIURL = "https://api.openai.com/v1/chat/completions" + defaultAPIFastModel = "gpt-4o-mini" + defaultAPIPowerModel = "gpt-5-mini" + defaultAPITimeout = 60 * time.Second + defaultAPIMaxTokens = 1000 + defaultAPITemperature = 1.0 +) + +// APIEngine implements LLMEngine interface for OpenAI API. +type APIEngine struct { + apiKey string + fastModel string + powerModel string + httpClient *http.Client + maxTokens int + temperature float64 + verbose bool +} + +// APIEngineOption is a functional option for APIEngine. +type APIEngineOption func(*APIEngine) + +// WithAPIFastModel sets the fast model. +func WithAPIFastModel(model string) APIEngineOption { + return func(e *APIEngine) { e.fastModel = model } +} + +// WithAPIPowerModel sets the power model. +func WithAPIPowerModel(model string) APIEngineOption { + return func(e *APIEngine) { e.powerModel = model } +} + +// WithAPITimeout sets the HTTP client timeout. +func WithAPITimeout(timeout time.Duration) APIEngineOption { + return func(e *APIEngine) { e.httpClient.Timeout = timeout } +} + +// WithAPIVerbose enables verbose logging. +func WithAPIVerbose(verbose bool) APIEngineOption { + return func(e *APIEngine) { e.verbose = verbose } +} + +// NewAPIEngine creates a new OpenAI API engine. +func NewAPIEngine(apiKey string, opts ...APIEngineOption) *APIEngine { + e := &APIEngine{ + apiKey: apiKey, + fastModel: defaultAPIFastModel, + powerModel: defaultAPIPowerModel, + httpClient: &http.Client{Timeout: defaultAPITimeout}, + maxTokens: defaultAPIMaxTokens, + temperature: defaultAPITemperature, + verbose: false, + } + + for _, opt := range opts { + opt(e) + } + + return e +} + +// Name returns the engine identifier. +func (e *APIEngine) Name() string { + return "openai-api" +} + +// IsAvailable checks if the engine can be used. 
+func (e *APIEngine) IsAvailable() bool { + return e.apiKey != "" +} + +// Capabilities returns engine capabilities. +func (e *APIEngine) Capabilities() Capabilities { + return Capabilities{ + SupportsTemperature: true, + SupportsMaxTokens: true, + SupportsComplexity: true, + SupportsStreaming: true, + MaxContextLength: 128000, + Models: []string{e.fastModel, e.powerModel}, + } +} + +// Execute sends the request via OpenAI API. +func (e *APIEngine) Execute(ctx context.Context, req *Request) (string, error) { + if e.apiKey == "" { + return "", fmt.Errorf("OpenAI API key not configured") + } + + // Select model based on complexity + model := e.fastModel + var reasoningEffort string + + switch req.Complexity { + case ComplexityMinimal: + model = e.fastModel + reasoningEffort = "minimal" + case ComplexityLow: + model = e.fastModel + case ComplexityMedium: + model = e.powerModel + reasoningEffort = "low" + case ComplexityHigh: + model = e.powerModel + reasoningEffort = "medium" + } + + // Build request body + maxTokens := req.MaxTokens + if maxTokens == 0 { + maxTokens = e.maxTokens + } + + temperature := req.Temperature + if temperature == 0 { + temperature = e.temperature + } + + apiReq := openAIAPIRequest{ + Model: model, + Messages: []openAIAPIMessage{ + {Role: "user", Content: req.CombinedPrompt()}, + }, + MaxTokens: maxTokens, + Temperature: temperature, + } + + if reasoningEffort != "" { + apiReq.ReasoningEffort = reasoningEffort + } + + jsonData, err := json.Marshal(apiReq) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, openAIAPIURL, bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+e.apiKey) + + if e.verbose { + fmt.Fprintf(os.Stderr, "OpenAI API request:\n Model: %s\n Complexity: %s\n Prompt 
length: %d chars\n", + model, req.Complexity, len(req.UserPrompt)) + } + + resp, err := e.httpClient.Do(httpReq) + if err != nil { + return "", fmt.Errorf("failed to send request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(body)) + } + + var apiResp openAIAPIResponse + if err := json.Unmarshal(body, &apiResp); err != nil { + return "", fmt.Errorf("failed to unmarshal response: %w", err) + } + + if apiResp.Error != nil { + return "", fmt.Errorf("OpenAI API error: %s (type: %s, code: %s)", + apiResp.Error.Message, apiResp.Error.Type, apiResp.Error.Code) + } + + if len(apiResp.Choices) == 0 { + return "", fmt.Errorf("no choices in response") + } + + content := apiResp.Choices[0].Message.Content + + if e.verbose { + fmt.Fprintf(os.Stderr, "OpenAI API response:\n Tokens: %d\n Content length: %d chars\n", + apiResp.Usage.TotalTokens, len(content)) + } + + return content, nil +} + +// SetVerbose sets verbose mode. +func (e *APIEngine) SetVerbose(verbose bool) { + e.verbose = verbose +} + +// openAIAPIRequest represents the OpenAI API request structure. +type openAIAPIRequest struct { + Model string `json:"model"` + Messages []openAIAPIMessage `json:"messages"` + MaxTokens int `json:"max_completion_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` +} + +// openAIAPIMessage represents a message in the OpenAI API request. +type openAIAPIMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// openAIAPIResponse represents the OpenAI API response structure. 
+type openAIAPIResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Index int `json:"index"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + Error *struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` + } `json:"error,omitempty"` +} diff --git a/internal/llm/engine/api_test.go b/internal/llm/engine/api_test.go new file mode 100644 index 0000000..121df05 --- /dev/null +++ b/internal/llm/engine/api_test.go @@ -0,0 +1,64 @@ +package engine + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewAPIEngine(t *testing.T) { + t.Run("with api key", func(t *testing.T) { + engine := NewAPIEngine("sk-test-key") + assert.NotNil(t, engine) + assert.Equal(t, "openai-api", engine.Name()) + assert.True(t, engine.IsAvailable()) + }) + + t.Run("without api key", func(t *testing.T) { + engine := NewAPIEngine("") + assert.NotNil(t, engine) + assert.False(t, engine.IsAvailable()) + }) + + t.Run("with options", func(t *testing.T) { + engine := NewAPIEngine("sk-test-key", + WithAPIFastModel("gpt-4o"), + WithAPIPowerModel("o3-mini"), + WithAPIVerbose(true), + ) + assert.NotNil(t, engine) + caps := engine.Capabilities() + assert.Contains(t, caps.Models, "gpt-4o") + assert.Contains(t, caps.Models, "o3-mini") + }) +} + +func TestAPIEngine_Capabilities(t *testing.T) { + engine := NewAPIEngine("sk-test-key") + caps := engine.Capabilities() + + assert.True(t, caps.SupportsTemperature) + assert.True(t, caps.SupportsMaxTokens) + assert.True(t, caps.SupportsComplexity) + assert.True(t, caps.SupportsStreaming) + assert.Equal(t, 128000, 
caps.MaxContextLength) + assert.Len(t, caps.Models, 2) +} + +func TestAPIEngine_Name(t *testing.T) { + engine := NewAPIEngine("sk-test-key") + assert.Equal(t, "openai-api", engine.Name()) +} + +func TestAPIEngine_IsAvailable(t *testing.T) { + t.Run("available with key", func(t *testing.T) { + engine := NewAPIEngine("sk-test-key") + assert.True(t, engine.IsAvailable()) + }) + + t.Run("not available without key", func(t *testing.T) { + engine := NewAPIEngine("") + assert.False(t, engine.IsAvailable()) + }) +} + diff --git a/internal/llm/engine/cli.go b/internal/llm/engine/cli.go new file mode 100644 index 0000000..1d605bc --- /dev/null +++ b/internal/llm/engine/cli.go @@ -0,0 +1,224 @@ +package engine + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "time" + + "github.com/DevSymphony/sym-cli/internal/llm/engine/cliprovider" +) + +const ( + defaultCLITimeout = 120 * time.Second +) + +// Re-export CLI provider types for backward compatibility. +type CLIProviderType = cliprovider.Type + +const ( + // ProviderClaude is the Claude CLI provider. + ProviderClaude CLIProviderType = cliprovider.TypeClaude + // ProviderGemini is the Gemini CLI provider. + ProviderGemini CLIProviderType = cliprovider.TypeGemini +) + +// CLIProvider is an alias to cliprovider.Provider. +type CLIProvider = cliprovider.Provider + +// CLIInfo is an alias to cliprovider.Info. +type CLIInfo = cliprovider.Info + +// SupportedProviders returns all supported CLI providers. +func SupportedProviders() map[CLIProviderType]*CLIProvider { + return cliprovider.Supported() +} + +// GetProvider returns the provider for the given type. +func GetProvider(providerType CLIProviderType) (*CLIProvider, error) { + return cliprovider.Get(providerType) +} + +// DetectAvailableCLIs scans for installed CLI tools. +func DetectAvailableCLIs() []CLIInfo { + return cliprovider.Detect() +} + +// GetProviderByCommand finds a provider by its command name. 
+func GetProviderByCommand(command string) (*CLIProvider, error) { + return cliprovider.GetByCommand(command) +} + +// CLIEngine implements LLMEngine interface for CLI-based LLM tools. +type CLIEngine struct { + provider *cliprovider.Provider + model string + largeModel string + timeout time.Duration + verbose bool + customPath string +} + +// CLIEngineOption is a functional option for CLIEngine. +type CLIEngineOption func(*CLIEngine) + +// WithCLIModel sets the default model. +func WithCLIModel(model string) CLIEngineOption { + return func(e *CLIEngine) { e.model = model } +} + +// WithCLILargeModel sets the model for high complexity tasks. +func WithCLILargeModel(model string) CLIEngineOption { + return func(e *CLIEngine) { e.largeModel = model } +} + +// WithCLITimeout sets the execution timeout. +func WithCLITimeout(timeout time.Duration) CLIEngineOption { + return func(e *CLIEngine) { e.timeout = timeout } +} + +// WithCLIVerbose enables verbose logging. +func WithCLIVerbose(verbose bool) CLIEngineOption { + return func(e *CLIEngine) { e.verbose = verbose } +} + +// WithCLIPath sets a custom path to the CLI executable. +func WithCLIPath(path string) CLIEngineOption { + return func(e *CLIEngine) { e.customPath = path } +} + +// NewCLIEngine creates a new CLI engine for the given provider. +func NewCLIEngine(providerType CLIProviderType, opts ...CLIEngineOption) (*CLIEngine, error) { + provider, err := cliprovider.Get(providerType) + if err != nil { + return nil, err + } + + e := &CLIEngine{ + provider: provider, + model: provider.DefaultModel, + largeModel: provider.LargeModel, + timeout: defaultCLITimeout, + verbose: false, + } + + for _, opt := range opts { + opt(e) + } + + return e, nil +} + +// Name returns the engine identifier. +func (e *CLIEngine) Name() string { + return fmt.Sprintf("cli-%s", e.provider.Type) +} + +// IsAvailable checks if the engine can be used. 
+func (e *CLIEngine) IsAvailable() bool { + cmdPath := e.getCommandPath() + _, err := exec.LookPath(cmdPath) + return err == nil +} + +// Capabilities returns engine capabilities. +func (e *CLIEngine) Capabilities() Capabilities { + models := []string{e.model} + if e.largeModel != "" && e.largeModel != e.model { + models = append(models, e.largeModel) + } + + return Capabilities{ + SupportsTemperature: e.provider.SupportsTemperature, + SupportsMaxTokens: e.provider.SupportsMaxTokens, + SupportsComplexity: e.largeModel != "", + SupportsStreaming: false, + MaxContextLength: 0, + Models: models, + } +} + +// Execute sends the request via CLI. +func (e *CLIEngine) Execute(ctx context.Context, req *Request) (string, error) { + model := e.model + if req.Complexity >= ComplexityHigh && e.largeModel != "" { + model = e.largeModel + } + + prompt := req.CombinedPrompt() + args := e.provider.BuildArgs(model, prompt) + args = e.appendOptionalFlags(args, req) + + if e.verbose { + fmt.Fprintf(os.Stderr, "CLI Engine (%s) request:\n Model: %s\n Complexity: %s\n Prompt length: %d chars\n", + e.provider.Type, model, req.Complexity, len(prompt)) + } + + cmdCtx, cancel := context.WithTimeout(ctx, e.timeout) + defer cancel() + + cmdPath := e.getCommandPath() + cmd := exec.CommandContext(cmdCtx, cmdPath, args...) 
+ + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + if err != nil { + if cmdCtx.Err() == context.DeadlineExceeded { + return "", fmt.Errorf("CLI command timed out after %v", e.timeout) + } + return "", fmt.Errorf("CLI command failed: %w\nstdout: %s\nstderr: %s", err, stdout.String(), stderr.String()) + } + + response, err := e.provider.ParseResponse(stdout.Bytes()) + if err != nil { + return "", fmt.Errorf("failed to parse CLI response: %w", err) + } + + if e.verbose { + fmt.Fprintf(os.Stderr, "CLI Engine (%s) response:\n Content length: %d chars\n", + e.provider.Type, len(response)) + } + + return response, nil +} + +// getCommandPath returns the path to the CLI executable. +func (e *CLIEngine) getCommandPath() string { + if e.customPath != "" { + return e.customPath + } + return e.provider.Command +} + +// appendOptionalFlags adds optional flags based on request parameters. +func (e *CLIEngine) appendOptionalFlags(args []string, req *Request) []string { + if e.provider.SupportsMaxTokens && e.provider.MaxTokensFlag != "" && req.MaxTokens > 0 { + args = append(args, e.provider.MaxTokensFlag, fmt.Sprintf("%d", req.MaxTokens)) + } + + if e.provider.SupportsTemperature && e.provider.TemperatureFlag != "" && req.Temperature > 0 { + args = append(args, e.provider.TemperatureFlag, fmt.Sprintf("%.2f", req.Temperature)) + } + + return args +} + +// GetProvider returns the underlying provider. +func (e *CLIEngine) GetProvider() *CLIProvider { + return e.provider +} + +// GetModel returns the current model. +func (e *CLIEngine) GetModel() string { + return e.model +} + +// SetVerbose sets verbose mode. 
+func (e *CLIEngine) SetVerbose(verbose bool) { + e.verbose = verbose +} diff --git a/internal/llm/engine/cli_test.go b/internal/llm/engine/cli_test.go new file mode 100644 index 0000000..bfe3dfe --- /dev/null +++ b/internal/llm/engine/cli_test.go @@ -0,0 +1,77 @@ +package engine + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCLIEngine(t *testing.T) { + t.Run("valid provider", func(t *testing.T) { + engine, err := NewCLIEngine(ProviderClaude) + require.NoError(t, err) + assert.NotNil(t, engine) + assert.Equal(t, "cli-claude", engine.Name()) + }) + + t.Run("with options", func(t *testing.T) { + engine, err := NewCLIEngine( + ProviderClaude, + WithCLIModel("custom-model"), + WithCLILargeModel("large-model"), + WithCLIVerbose(true), + ) + require.NoError(t, err) + assert.Equal(t, "custom-model", engine.GetModel()) + }) + + t.Run("invalid provider", func(t *testing.T) { + _, err := NewCLIEngine(CLIProviderType("invalid")) + assert.Error(t, err) + }) +} + +func TestCLIEngine_Capabilities(t *testing.T) { + engine, err := NewCLIEngine(ProviderClaude) + require.NoError(t, err) + + caps := engine.Capabilities() + + assert.False(t, caps.SupportsMaxTokens) + assert.False(t, caps.SupportsStreaming) + assert.True(t, caps.SupportsComplexity) // Has LargeModel + assert.NotEmpty(t, caps.Models) +} + +func TestDetectAvailableCLIs(t *testing.T) { + clis := DetectAvailableCLIs() + + // Should return info for all supported providers + assert.Len(t, clis, 2) + + // Each CLI should have provider and name set + for _, cli := range clis { + assert.NotEmpty(t, cli.Provider) + assert.NotEmpty(t, cli.Name) + } +} + +func TestGetProviderByCommand(t *testing.T) { + t.Run("claude command", func(t *testing.T) { + provider, err := GetProviderByCommand("claude") + require.NoError(t, err) + assert.Equal(t, ProviderClaude, provider.Type) + }) + + t.Run("gemini command", func(t *testing.T) { + provider, err := 
GetProviderByCommand("gemini") + require.NoError(t, err) + assert.Equal(t, ProviderGemini, provider.Type) + }) + + t.Run("unknown command", func(t *testing.T) { + _, err := GetProviderByCommand("unknown") + assert.Error(t, err) + }) +} diff --git a/internal/llm/engine/cliprovider/claude.go b/internal/llm/engine/cliprovider/claude.go new file mode 100644 index 0000000..4961c98 --- /dev/null +++ b/internal/llm/engine/cliprovider/claude.go @@ -0,0 +1,30 @@ +package cliprovider + +import "strings" + +func newClaudeProvider() *Provider { + return &Provider{ + Type: TypeClaude, + DisplayName: "Claude CLI", + Command: "claude", + DefaultModel: "claude-haiku-4-5-20251001", + LargeModel: "claude-sonnet-4-5-20250929", + BuildArgs: func(model string, prompt string) []string { + args := []string{ + "-p", prompt, + "--output-format", "text", + } + if model != "" { + args = append(args, "--model", model) + } + return args + }, + ParseResponse: func(output []byte) (string, error) { + return strings.TrimSpace(string(output)), nil + }, + SupportsMaxTokens: false, + MaxTokensFlag: "", + SupportsTemperature: false, + TemperatureFlag: "", + } +} diff --git a/internal/llm/engine/cliprovider/gemini.go b/internal/llm/engine/cliprovider/gemini.go new file mode 100644 index 0000000..f299a8e --- /dev/null +++ b/internal/llm/engine/cliprovider/gemini.go @@ -0,0 +1,27 @@ +package cliprovider + +import "strings" + +func newGeminiProvider() *Provider { + return &Provider{ + Type: TypeGemini, + DisplayName: "Gemini CLI", + Command: "gemini", + DefaultModel: "gemini-2.0-flash", + LargeModel: "gemini-2.5-pro-preview-06-05", + BuildArgs: func(model string, prompt string) []string { + return []string{ + "prompt", + "-m", model, + prompt, + } + }, + ParseResponse: func(output []byte) (string, error) { + return strings.TrimSpace(string(output)), nil + }, + SupportsMaxTokens: true, + MaxTokensFlag: "--max-tokens", + SupportsTemperature: true, + TemperatureFlag: "--temperature", + } +} diff --git 
a/internal/llm/engine/cliprovider/provider.go b/internal/llm/engine/cliprovider/provider.go new file mode 100644 index 0000000..7922d3f --- /dev/null +++ b/internal/llm/engine/cliprovider/provider.go @@ -0,0 +1,141 @@ +package cliprovider + +import ( + "fmt" + "os/exec" + "strings" +) + +// Type represents supported CLI provider types. +type Type string + +const ( + // TypeClaude is the Claude CLI provider. + TypeClaude Type = "claude" + // TypeGemini is the Gemini CLI provider. + TypeGemini Type = "gemini" +) + +// IsValid checks if the provider type is valid. +func (t Type) IsValid() bool { + switch t { + case TypeClaude, TypeGemini: + return true + default: + return false + } +} + +// Provider defines how to interact with a specific CLI tool. +type Provider struct { + // Type is the provider identifier. + Type Type + + // DisplayName is the human-readable name. + DisplayName string + + // Command is the executable name or path. + Command string + + // DefaultModel is the default model to use. + DefaultModel string + + // LargeModel is the model for high complexity tasks (optional). + LargeModel string + + // BuildArgs constructs CLI arguments for the given request. + BuildArgs func(model string, prompt string) []string + + // ParseResponse extracts text from CLI output. + ParseResponse func(output []byte) (string, error) + + // SupportsMaxTokens indicates if --max-tokens or similar is supported. + SupportsMaxTokens bool + + // MaxTokensFlag is the flag name for max tokens (e.g., "--max-tokens"). + MaxTokensFlag string + + // SupportsTemperature indicates if temperature is supported. + SupportsTemperature bool + + // TemperatureFlag is the flag name for temperature. + TemperatureFlag string +} + +// Info represents detected CLI information. +type Info struct { + Provider Type + Name string + Path string + Version string + Available bool +} + +// Supported returns all supported CLI providers. 
+func Supported() map[Type]*Provider { + return map[Type]*Provider{ + TypeClaude: newClaudeProvider(), + TypeGemini: newGeminiProvider(), + } +} + +// Get returns the provider for the given type. +func Get(providerType Type) (*Provider, error) { + providers := Supported() + provider, ok := providers[providerType] + if !ok { + return nil, fmt.Errorf("unsupported CLI provider: %s", providerType) + } + return provider, nil +} + +// Detect scans for installed CLI tools. +func Detect() []Info { + var results []Info + + providers := Supported() + for providerType, provider := range providers { + info := Info{ + Provider: providerType, + Name: provider.DisplayName, + Available: false, + } + + path, err := exec.LookPath(provider.Command) + if err == nil { + info.Path = path + info.Available = true + info.Version = getProviderVersion(provider) + } + + results = append(results, info) + } + + return results +} + +// GetByCommand finds a provider by its command name. +func GetByCommand(command string) (*Provider, error) { + providers := Supported() + for _, provider := range providers { + if provider.Command == command { + return provider, nil + } + } + return nil, fmt.Errorf("no provider found for command: %s", command) +} + +func getProviderVersion(provider *Provider) string { + cmd := exec.Command(provider.Command, "--version") // #nosec G204 + output, err := cmd.Output() + if err != nil { + return "" + } + + lines := strings.Split(strings.TrimSpace(string(output)), "\n") + if len(lines) > 0 { + return strings.TrimSpace(lines[0]) + } + + return "" +} diff --git a/internal/llm/engine/cliprovider/provider_test.go b/internal/llm/engine/cliprovider/provider_test.go new file mode 100644 index 0000000..48e637e --- /dev/null +++ b/internal/llm/engine/cliprovider/provider_test.go @@ -0,0 +1,120 @@ +package cliprovider + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestType_IsValid(t *testing.T) { + tests := 
[]struct { + name string + typ Type + want bool + }{ + {"claude", TypeClaude, true}, + {"gemini", TypeGemini, true}, + {"invalid", Type("invalid"), false}, + {"empty", Type(""), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.typ.IsValid()) + }) + } +} + +func TestSupported(t *testing.T) { + providers := Supported() + + assert.Len(t, providers, 2) + assert.Contains(t, providers, TypeClaude) + assert.Contains(t, providers, TypeGemini) +} + +func TestGet(t *testing.T) { + t.Run("claude", func(t *testing.T) { + provider, err := Get(TypeClaude) + require.NoError(t, err) + assert.Equal(t, TypeClaude, provider.Type) + assert.Equal(t, "Claude CLI", provider.DisplayName) + assert.Equal(t, "claude", provider.Command) + }) + + t.Run("gemini", func(t *testing.T) { + provider, err := Get(TypeGemini) + require.NoError(t, err) + assert.Equal(t, TypeGemini, provider.Type) + assert.Equal(t, "Gemini CLI", provider.DisplayName) + assert.Equal(t, "gemini", provider.Command) + }) + + t.Run("invalid", func(t *testing.T) { + _, err := Get(Type("invalid")) + assert.Error(t, err) + }) +} + +func TestBuildArgs(t *testing.T) { + t.Run("claude", func(t *testing.T) { + provider := newClaudeProvider() + args := provider.BuildArgs("claude-3-opus", "Hello!") + + assert.Contains(t, args, "-p") + assert.Contains(t, args, "Hello!") + assert.Contains(t, args, "--model") + assert.Contains(t, args, "claude-3-opus") + }) + + t.Run("gemini", func(t *testing.T) { + provider := newGeminiProvider() + args := provider.BuildArgs("gemini-pro", "Hello!") + + assert.Contains(t, args, "prompt") + assert.Contains(t, args, "-m") + assert.Contains(t, args, "gemini-pro") + }) +} + +func TestParseResponse(t *testing.T) { + providers := Supported() + + for typ, provider := range providers { + t.Run(string(typ), func(t *testing.T) { + resp, err := provider.ParseResponse([]byte(" trimmed response \n")) + require.NoError(t, err) + assert.Equal(t, "trimmed 
response", resp) + }) + } +} + +func TestDetect(t *testing.T) { + info := Detect() + assert.Len(t, info, 2) + + for _, cli := range info { + assert.NotEmpty(t, cli.Provider) + assert.NotEmpty(t, cli.Name) + } +} + +func TestGetByCommand(t *testing.T) { + t.Run("claude", func(t *testing.T) { + provider, err := GetByCommand("claude") + require.NoError(t, err) + assert.Equal(t, TypeClaude, provider.Type) + }) + + t.Run("gemini", func(t *testing.T) { + provider, err := GetByCommand("gemini") + require.NoError(t, err) + assert.Equal(t, TypeGemini, provider.Type) + }) + + t.Run("invalid", func(t *testing.T) { + _, err := GetByCommand("unknown") + assert.Error(t, err) + }) +} diff --git a/internal/llm/engine/engine.go b/internal/llm/engine/engine.go new file mode 100644 index 0000000..99756cc --- /dev/null +++ b/internal/llm/engine/engine.go @@ -0,0 +1,113 @@ +package engine + +import "context" + +// Complexity represents task complexity hint (engine-agnostic). +// This allows callers to express intent without coupling to specific engine features. +type Complexity int + +const ( + // ComplexityMinimal is for trivial lookups or boilerplate prompts. + ComplexityMinimal Complexity = iota + // ComplexityLow is for simple transformations, parsing, basic formatting. + ComplexityLow + // ComplexityMedium is for analysis, routing decisions, moderate reasoning. + ComplexityMedium + // ComplexityHigh is for complex reasoning, code generation, deep analysis. + ComplexityHigh +) + +// String returns human-readable complexity name. +func (c Complexity) String() string { + switch c { + case ComplexityMinimal: + return "minimal" + case ComplexityLow: + return "low" + case ComplexityMedium: + return "medium" + case ComplexityHigh: + return "high" + default: + return "unknown" + } +} + +// Request represents an engine-agnostic LLM request. +// All engines receive this unified request format and interpret it according to their capabilities. 
+type Request struct { + SystemPrompt string + UserPrompt string + MaxTokens int + Temperature float64 + Complexity Complexity +} + +// CombinedPrompt returns system and user prompts combined. +func (r *Request) CombinedPrompt() string { + if r.SystemPrompt == "" { + return r.UserPrompt + } + return r.SystemPrompt + "\n\n" + r.UserPrompt +} + +// LLMEngine is the interface for LLM execution engines. +type LLMEngine interface { + // Execute sends request and returns response text. + Execute(ctx context.Context, req *Request) (string, error) + + // Name returns engine identifier. + Name() string + + // IsAvailable checks if this engine can currently be used. + IsAvailable() bool + + // Capabilities returns what features this engine supports. + Capabilities() Capabilities +} + +// Capabilities describes what features an engine supports. +// This enables graceful degradation when features aren't available. +type Capabilities struct { + // SupportsTemperature indicates if temperature parameter is respected. + SupportsTemperature bool + + // SupportsMaxTokens indicates if max_tokens parameter is respected. + SupportsMaxTokens bool + + // SupportsComplexity indicates if complexity hint affects model selection. + SupportsComplexity bool + + // SupportsStreaming indicates if streaming responses are supported. + SupportsStreaming bool + + // MaxContextLength is the maximum input context length (0 = unknown). + MaxContextLength int + + // Models lists available models for this engine. + Models []string +} + +// Mode represents the preferred engine selection mode. +type Mode string + +const ( + // ModeAuto automatically selects the best available engine. + ModeAuto Mode = "auto" + // ModeMCP forces MCP sampling engine. + ModeMCP Mode = "mcp" + // ModeCLI forces CLI engine. + ModeCLI Mode = "cli" + // ModeAPI forces API engine. + ModeAPI Mode = "api" +) + +// IsValid checks if the engine mode is valid. 
+func (m Mode) IsValid() bool { + switch m { + case ModeAuto, ModeMCP, ModeCLI, ModeAPI: + return true + default: + return false + } +} diff --git a/internal/llm/engine/engine_test.go b/internal/llm/engine/engine_test.go new file mode 100644 index 0000000..2954b45 --- /dev/null +++ b/internal/llm/engine/engine_test.go @@ -0,0 +1,98 @@ +package engine + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestComplexity_String(t *testing.T) { + tests := []struct { + name string + complexity Complexity + want string + }{ + {"low", ComplexityLow, "low"}, + {"medium", ComplexityMedium, "medium"}, + {"high", ComplexityHigh, "high"}, + {"unknown", Complexity(99), "unknown"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.complexity.String()) + }) + } +} + +func TestRequest_CombinedPrompt(t *testing.T) { + tests := []struct { + name string + req Request + want string + }{ + { + name: "with system and user prompt", + req: Request{ + SystemPrompt: "You are a helpful assistant.", + UserPrompt: "Hello!", + }, + want: "You are a helpful assistant.\n\nHello!", + }, + { + name: "only user prompt", + req: Request{ + SystemPrompt: "", + UserPrompt: "Hello!", + }, + want: "Hello!", + }, + { + name: "empty prompts", + req: Request{ + SystemPrompt: "", + UserPrompt: "", + }, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.req.CombinedPrompt()) + }) + } +} + +func TestMode_IsValid(t *testing.T) { + tests := []struct { + name string + mode Mode + valid bool + }{ + {"auto", ModeAuto, true}, + {"mcp", ModeMCP, true}, + {"cli", ModeCLI, true}, + {"api", ModeAPI, true}, + {"invalid", Mode("invalid"), false}, + {"empty", Mode(""), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.valid, tt.mode.IsValid()) + }) + } +} + +func TestCapabilities_Default(t *testing.T) { + caps := Capabilities{} + + 
assert.False(t, caps.SupportsTemperature) + assert.False(t, caps.SupportsMaxTokens) + assert.False(t, caps.SupportsComplexity) + assert.False(t, caps.SupportsStreaming) + assert.Equal(t, 0, caps.MaxContextLength) + assert.Nil(t, caps.Models) +} + diff --git a/internal/llm/engine/mcp.go b/internal/llm/engine/mcp.go new file mode 100644 index 0000000..909f231 --- /dev/null +++ b/internal/llm/engine/mcp.go @@ -0,0 +1,116 @@ +package engine + +import ( + "context" + "fmt" + "os" + + mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// MCPEngine implements LLMEngine interface for MCP sampling. +// It delegates LLM calls to the host application via MCP's CreateMessage. +type MCPEngine struct { + session *mcpsdk.ServerSession + verbose bool +} + +// MCPEngineOption is a functional option for MCPEngine. +type MCPEngineOption func(*MCPEngine) + +// WithMCPVerbose enables verbose logging. +func WithMCPVerbose(verbose bool) MCPEngineOption { + return func(e *MCPEngine) { e.verbose = verbose } +} + +// NewMCPEngine creates a new MCP sampling engine. +func NewMCPEngine(session *mcpsdk.ServerSession, opts ...MCPEngineOption) *MCPEngine { + e := &MCPEngine{ + session: session, + verbose: false, + } + + for _, opt := range opts { + opt(e) + } + + return e +} + +// Name returns the engine identifier. +func (e *MCPEngine) Name() string { + return "mcp-sampling" +} + +// IsAvailable checks if the engine can be used. +func (e *MCPEngine) IsAvailable() bool { + return e.session != nil +} + +// Capabilities returns engine capabilities. +// MCP sampling capabilities depend on the host LLM, so we're conservative here. 
+func (e *MCPEngine) Capabilities() Capabilities { + return Capabilities{ + SupportsTemperature: false, // Host decides + SupportsMaxTokens: true, // Passed to CreateMessage + SupportsComplexity: false, // Host decides model + SupportsStreaming: false, // Not implemented + MaxContextLength: 0, // Unknown + Models: nil, // Host decides + } +} + +// Execute sends the request via MCP sampling. +func (e *MCPEngine) Execute(ctx context.Context, req *Request) (string, error) { + if e.session == nil { + return "", fmt.Errorf("MCP session not available") + } + + if e.verbose { + fmt.Fprintf(os.Stderr, "MCP Sampling request:\n MaxTokens: %d\n Prompt length: %d chars\n", + req.MaxTokens, len(req.UserPrompt)) + } + + maxTokens := req.MaxTokens + if maxTokens == 0 { + maxTokens = defaultAPIMaxTokens + } + + result, err := e.session.CreateMessage(ctx, &mcpsdk.CreateMessageParams{ + Messages: []*mcpsdk.SamplingMessage{ + { + Role: "user", + Content: &mcpsdk.TextContent{Text: req.CombinedPrompt()}, + }, + }, + MaxTokens: int64(maxTokens), + }) + if err != nil { + return "", fmt.Errorf("MCP sampling failed: %w", err) + } + + var response string + if textContent, ok := result.Content.(*mcpsdk.TextContent); ok { + response = textContent.Text + } else { + return "", fmt.Errorf("unexpected content type from MCP sampling") + } + + if e.verbose { + fmt.Fprintf(os.Stderr, "MCP Sampling response:\n Model: %s\n Content length: %d chars\n", + result.Model, len(response)) + } + + return response, nil +} + +// GetSession returns the underlying MCP session. +func (e *MCPEngine) GetSession() *mcpsdk.ServerSession { + return e.session +} + +// SetVerbose sets verbose mode. 
+func (e *MCPEngine) SetVerbose(verbose bool) { + e.verbose = verbose +} + diff --git a/internal/mcp/server.go b/internal/mcp/server.go index d2f927a..1a41c37 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -34,14 +34,9 @@ func ConvertPolicyWithLLM(userPolicyPath, codePolicyPath string) error { return fmt.Errorf("failed to parse user policy: %w", err) } - // Setup LLM client - apiKey := envutil.GetAPIKey("OPENAI_API_KEY") - if apiKey == "" { - return fmt.Errorf("OPENAI_API_KEY not found in environment or .sym/.env") - } - - llmClient := llm.NewClient(apiKey, - llm.WithTimeout(30*time.Second), + // Setup LLM client (backend auto-selection via @llm) + llmClient := llm.NewClient( + llm.WithTimeout(30 * time.Second), ) // Create converter with output directory @@ -479,19 +474,12 @@ func (s *Server) handleValidateCode(ctx context.Context, session *sdkmcp.ServerS var llmClient *llm.Client if session != nil { // MCP mode: use host LLM via sampling - llmClient = llm.NewClient("", llm.WithMCPSession(session)) + llmClient = llm.NewClient(llm.WithMCPSession(session)) fmt.Fprintf(os.Stderr, "โœ“ Using host LLM via MCP sampling\n") } else { - // API mode: use OpenAI API directly - apiKey := envutil.GetAPIKey("OPENAI_API_KEY") - if apiKey == "" { - return nil, &RPCError{ - Code: -32000, - Message: "OPENAI_API_KEY not found in environment or .sym/.env", - } - } - llmClient = llm.NewClient(apiKey) - fmt.Fprintf(os.Stderr, "โœ“ Using OpenAI API directly\n") + // Auto mode: use configured LLM backend (CLI/API) + llmClient = llm.NewClient() + fmt.Fprintf(os.Stderr, "โœ“ Using configured LLM backend\n") } // Create unified validator that handles all engines + RBAC diff --git a/internal/server/server.go b/internal/server/server.go index d7f52bd..b6947b3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -667,17 +667,9 @@ func (s *Server) handleConvert(w http.ResponseWriter, r *http.Request) { // Determine output directory (same as input file) 
outputDir := filepath.Dir(policyPath) - // Get API key - apiKey, err := s.getAPIKey() - if err != nil { - fmt.Printf("Warning: %v, conversion may be limited\n", err) - apiKey = "" - } - - // Setup LLM client + // Setup LLM client (backend auto-selection via @llm) timeout := 30 * time.Second llmClient := llm.NewClient( - apiKey, llm.WithTimeout(timeout), ) @@ -719,12 +711,3 @@ func (s *Server) handleConvert(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(result) } - -// getAPIKey retrieves the OpenAI API key from environment or .sym/.env -func (s *Server) getAPIKey() (string, error) { - key := envutil.GetAPIKey("OPENAI_API_KEY") - if key == "" { - return "", fmt.Errorf("OPENAI_API_KEY not found in environment or .sym/.env") - } - return key, nil -} diff --git a/tests/e2e/full_workflow_test.go b/tests/e2e/full_workflow_test.go index 6c1a418..696eb28 100644 --- a/tests/e2e/full_workflow_test.go +++ b/tests/e2e/full_workflow_test.go @@ -74,8 +74,7 @@ func TestE2E_FullWorkflow(t *testing.T) { t.Log("STEP 2: Converting user policy using LLM") client := llm.NewClient( - apiKey, - llm.WithTimeout(30*time.Second), + llm.WithTimeout(30 * time.Second), ) outputDir := filepath.Join(testDir, ".sym") @@ -333,7 +332,7 @@ func TestE2E_CodeGenerationFeedbackLoop(t *testing.T) { }, } - client := llm.NewClient(apiKey) + client := llm.NewClient() v := validator.NewLLMValidator(client, policy) ctx := context.Background() diff --git a/tests/e2e/mcp_integration_test.go b/tests/e2e/mcp_integration_test.go index 6bb0187..3767d7e 100644 --- a/tests/e2e/mcp_integration_test.go +++ b/tests/e2e/mcp_integration_test.go @@ -142,8 +142,7 @@ func TestMCP_ValidateAIGeneratedCode(t *testing.T) { // Create LLM client client := llm.NewClient( - apiKey, - llm.WithTimeout(30*time.Second), + llm.WithTimeout(30 * time.Second), ) // Create validator @@ -379,7 +378,7 @@ func TestMCP_EndToEndWorkflow(t *testing.T) { // Step 4: 
Validate generated code t.Log("STEP 4: Validating AI-generated code") - client := llm.NewClient(apiKey) + client := llm.NewClient() v := validator.NewLLMValidator(client, policy) result, err := v.Validate(context.Background(), []validator.GitChange{ diff --git a/tests/e2e/validator_test.go b/tests/e2e/validator_test.go index f9643dc..23e890a 100644 --- a/tests/e2e/validator_test.go +++ b/tests/e2e/validator_test.go @@ -31,7 +31,7 @@ func TestE2E_ValidatorWithPolicy(t *testing.T) { require.NotEmpty(t, policy.Rules, "Policy should have rules") // Create LLM client - client := llm.NewClient(apiKey) + client := llm.NewClient() // Create validator v := validator.NewLLMValidator(client, policy) @@ -83,7 +83,7 @@ func TestE2E_ValidatorWithGoodCode(t *testing.T) { require.NoError(t, err) // Create LLM client - client := llm.NewClient(apiKey) + client := llm.NewClient() // Create validator v := validator.NewLLMValidator(client, policy) @@ -182,7 +182,7 @@ func TestE2E_ValidatorFilter(t *testing.T) { require.NoError(t, err) // Create LLM client - client := llm.NewClient(apiKey) + client := llm.NewClient() // Create validator v := validator.NewLLMValidator(client, policy)