diff --git a/go.mod b/go.mod index a28b7b9..c7b8adf 100644 --- a/go.mod +++ b/go.mod @@ -8,20 +8,25 @@ require ( ) require ( - github.com/manifoldco/promptui v0.9.0 + github.com/AlecAivazis/survey/v2 v2.3.7 github.com/modelcontextprotocol/go-sdk v1.1.0 github.com/stretchr/testify v1.11.1 + golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 ) require ( - github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/jsonschema-go v0.3.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/mattn/go-colorable v0.1.2 // indirect + github.com/mattn/go-isatty v0.0.8 // indirect + github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c34c37e..7fd0ae2 100644 --- a/go.sum +++ b/go.sum @@ -1,20 +1,29 @@ -github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1 h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/AlecAivazis/survey/v2 v2.3.7 h1:6I/u8FvytdGsgonrYsVn2t8t4QiRnh6QSTqkkhIiSjQ= +github.com/AlecAivazis/survey/v2 v2.3.7/go.mod 
h1:xUTIdE4KCOIjsBAE1JYsUPoCqYdZ1reCfTwbto0Fduo= +github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63nhn5WAunQHLTznkw5W8b1Xc0dNjp83s= +github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI= +github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= +github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= -github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= 
+github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/modelcontextprotocol/go-sdk v1.1.0 h1:Qjayg53dnKC4UZ+792W21e4BpwEZBzwgRW6LrjLWSwA= github.com/modelcontextprotocol/go-sdk v1.1.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= @@ -27,19 +36,48 @@ github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4 github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto 
v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= -golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod 
h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/adapter/adapter.go b/internal/adapter/adapter.go index 33d09b8..7e04c30 100644 --- a/internal/adapter/adapter.go +++ b/internal/adapter/adapter.go @@ -114,8 +114,9 @@ type LinterConverter interface { // These hints are collected and included in the LLM prompt for rule routing. 
GetRoutingHints() []string - // ConvertRules converts user rules to native linter configuration using LLM - ConvertRules(ctx context.Context, rules []schema.UserRule, llmClient *llm.Client) (*LinterConfig, error) + // ConvertRules converts user rules to native linter configuration using LLM. + // Returns ConversionResult with per-rule success/failure tracking for fallback support. + ConvertRules(ctx context.Context, rules []schema.UserRule, provider llm.Provider) (*ConversionResult, error) } // LinterConfig represents a generated configuration file. @@ -124,3 +125,12 @@ type LinterConfig struct { Content []byte // File content Format string // "json", "xml", "yaml" } + +// ConversionResult contains the conversion output with per-rule tracking. +// This allows the main converter to know which rules succeeded vs failed, +// enabling fallback to llm-validator for failed rules. +type ConversionResult struct { + Config *LinterConfig // Generated config file (may be nil if all rules failed) + SuccessRules []string // Rule IDs that converted successfully + FailedRules []string // Rule IDs that couldn't be converted (fallback to llm-validator) +} diff --git a/internal/adapter/checkstyle/converter.go b/internal/adapter/checkstyle/converter.go index c8d12f2..d21abcd 100644 --- a/internal/adapter/checkstyle/converter.go +++ b/internal/adapter/checkstyle/converter.go @@ -64,15 +64,17 @@ type checkstyleConfig struct { Modules []checkstyleModule `xml:"module"` } -// ConvertRules converts user rules to Checkstyle configuration using LLM -func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, llmClient *llm.Client) (*adapter.LinterConfig, error) { - if llmClient == nil { - return nil, fmt.Errorf("LLM client is required") +// ConvertRules converts user rules to Checkstyle configuration using LLM. +// Returns ConversionResult with per-rule success/failure tracking for fallback support. 
+func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, provider llm.Provider) (*adapter.ConversionResult, error) { + if provider == nil { + return nil, fmt.Errorf("LLM provider is required") } // Convert rules in parallel type moduleResult struct { index int + ruleID string module *checkstyleModule err error } @@ -85,9 +87,10 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l go func(idx int, r schema.UserRule) { defer wg.Done() - module, err := c.convertSingleRule(ctx, r, llmClient) + module, err := c.convertSingleRule(ctx, r, provider) results <- moduleResult{ index: idx, + ruleID: r.ID, module: module, err: err, } @@ -99,23 +102,34 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l close(results) }() - // Collect modules + // Collect modules with per-rule tracking var modules []checkstyleModule - var errors []string + successRuleIDs := make([]string, 0) + failedRuleIDs := make([]string, 0) for result := range results { if result.err != nil { - errors = append(errors, fmt.Sprintf("Rule %d: %v", result.index+1, result.err)) + failedRuleIDs = append(failedRuleIDs, result.ruleID) continue } if result.module != nil { modules = append(modules, *result.module) + successRuleIDs = append(successRuleIDs, result.ruleID) + } else { + // Skipped = cannot be enforced by this linter + failedRuleIDs = append(failedRuleIDs, result.ruleID) } } + // Build result with tracking info + convResult := &adapter.ConversionResult{ + SuccessRules: successRuleIDs, + FailedRules: failedRuleIDs, + } + if len(modules) == 0 { - return nil, fmt.Errorf("no rules converted: %v", errors) + return convResult, nil } // Separate modules into Checker-level and TreeWalker-level @@ -173,15 +187,17 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l ` fullContent := []byte(xmlHeader + string(content)) - return &adapter.LinterConfig{ + convResult.Config = &adapter.LinterConfig{ 
Filename: "checkstyle.xml", Content: fullContent, Format: "xml", - }, nil + } + + return convResult, nil } // convertSingleRule converts a single rule using LLM -func (c *Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, llmClient *llm.Client) (*checkstyleModule, error) { +func (c *Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, provider llm.Provider) (*checkstyleModule, error) { systemPrompt := `You are a Checkstyle configuration expert. Convert natural language Java coding rules to Checkstyle modules. Return ONLY a JSON object (no markdown fences): @@ -255,8 +271,9 @@ Output: userPrompt := fmt.Sprintf("Convert this Java rule to Checkstyle module:\n\n%s", rule.Say) - // Call LLM with minimal complexity - response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) + // Call LLM + prompt := systemPrompt + "\n\n" + userPrompt + response, err := provider.Execute(ctx, prompt, llm.JSON) if err != nil { return nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/adapter/eslint/converter.go b/internal/adapter/eslint/converter.go index 66decfb..2b9ee42 100644 --- a/internal/adapter/eslint/converter.go +++ b/internal/adapter/eslint/converter.go @@ -45,15 +45,17 @@ func (c *Converter) GetRoutingHints() []string { } } -// ConvertRules converts user rules to ESLint configuration using LLM -func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, llmClient *llm.Client) (*adapter.LinterConfig, error) { - if llmClient == nil { - return nil, fmt.Errorf("LLM client is required") +// ConvertRules converts user rules to ESLint configuration using LLM. +// Returns ConversionResult with per-rule success/failure tracking for fallback support. 
+func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, provider llm.Provider) (*adapter.ConversionResult, error) { + if provider == nil { + return nil, fmt.Errorf("LLM provider is required") } // Convert rules in parallel using goroutines type ruleResult struct { index int + ruleID string ruleName string config interface{} err error @@ -68,9 +70,10 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l go func(idx int, r schema.UserRule) { defer wg.Done() - ruleName, config, err := c.convertSingleRule(ctx, r, llmClient) + ruleName, config, err := c.convertSingleRule(ctx, r, provider) results <- ruleResult{ index: idx, + ruleID: r.ID, ruleName: ruleName, config: config, err: err, @@ -84,64 +87,67 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l close(results) }() - // Collect results + // Collect results with per-rule tracking eslintRules := make(map[string]interface{}) - var errors []string - skippedCount := 0 + successRuleIDs := make([]string, 0) + failedRuleIDs := make([]string, 0) for result := range results { if result.err != nil { - errors = append(errors, fmt.Sprintf("Rule %d: %v", result.index+1, result.err)) - fmt.Fprintf(os.Stderr, "⚠️ ESLint rule %d conversion error: %v\n", result.index+1, result.err) + failedRuleIDs = append(failedRuleIDs, result.ruleID) + fmt.Fprintf(os.Stderr, "⚠️ ESLint rule %s conversion error: %v\n", result.ruleID, result.err) continue } if result.ruleName != "" { eslintRules[result.ruleName] = result.config - fmt.Fprintf(os.Stderr, "✓ ESLint rule %d → %s\n", result.index+1, result.ruleName) + successRuleIDs = append(successRuleIDs, result.ruleID) + fmt.Fprintf(os.Stderr, "✓ ESLint rule %s → %s\n", result.ruleID, result.ruleName) } else { - skippedCount++ - fmt.Fprintf(os.Stderr, "⊘ ESLint rule %d skipped (cannot be enforced by ESLint)\n", result.index+1) + // Skipped = cannot be enforced by this linter, fallback to llm-validator + failedRuleIDs 
= append(failedRuleIDs, result.ruleID) + fmt.Fprintf(os.Stderr, "⊘ ESLint rule %s skipped (cannot be enforced by ESLint)\n", result.ruleID) } } - if skippedCount > 0 { - fmt.Fprintf(os.Stderr, "ℹ️ %d rule(s) skipped for ESLint (will use llm-validator)\n", skippedCount) + // Build result with tracking info + convResult := &adapter.ConversionResult{ + SuccessRules: successRuleIDs, + FailedRules: failedRuleIDs, } - if len(eslintRules) == 0 { - return nil, fmt.Errorf("no rules converted successfully: %v", errors) - } + // Generate config only if at least one rule succeeded + if len(eslintRules) > 0 { + eslintConfig := map[string]interface{}{ + "env": map[string]bool{ + "es2021": true, + "node": true, + "browser": true, + }, + "parserOptions": map[string]interface{}{ + "ecmaVersion": "latest", + "sourceType": "module", + }, + "rules": eslintRules, + } - // Build ESLint configuration - eslintConfig := map[string]interface{}{ - "env": map[string]bool{ - "es2021": true, - "node": true, - "browser": true, - }, - "parserOptions": map[string]interface{}{ - "ecmaVersion": "latest", - "sourceType": "module", - }, - "rules": eslintRules, - } + content, err := json.MarshalIndent(eslintConfig, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal config: %w", err) + } - // Marshal to JSON - content, err := json.MarshalIndent(eslintConfig, "", " ") - if err != nil { - return nil, fmt.Errorf("failed to marshal config: %w", err) + convResult.Config = &adapter.LinterConfig{ + Filename: ".eslintrc.json", + Content: content, + Format: "json", + } } - return &adapter.LinterConfig{ - Filename: ".eslintrc.json", - Content: content, - Format: "json", - }, nil + return convResult, nil } // convertSingleRule converts a single user rule to ESLint rule using LLM -func (c *Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, llmClient *llm.Client) (string, interface{}, error) { +func (c *Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, 
provider llm.Provider) (string, interface{}, error) { systemPrompt := `You are an ESLint configuration expert. Convert natural language coding rules to ESLint rule configurations. Return ONLY a JSON object (no markdown fences) with this structure: @@ -217,8 +223,9 @@ Output: userPrompt += fmt.Sprintf("\nSeverity: %s", rule.Severity) } - // Call LLM with minimal complexity - response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) + // Call LLM + prompt := systemPrompt + "\n\n" + userPrompt + response, err := provider.Execute(ctx, prompt, llm.JSON) if err != nil { return "", nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/adapter/pmd/converter.go b/internal/adapter/pmd/converter.go index 55a76e5..334eddd 100644 --- a/internal/adapter/pmd/converter.go +++ b/internal/adapter/pmd/converter.go @@ -63,17 +63,19 @@ type pmdRule struct { Priority int `xml:"priority,omitempty"` } -// ConvertRules converts user rules to PMD configuration using LLM -func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, llmClient *llm.Client) (*adapter.LinterConfig, error) { - if llmClient == nil { - return nil, fmt.Errorf("LLM client is required") +// ConvertRules converts user rules to PMD configuration using LLM. +// Returns ConversionResult with per-rule success/failure tracking for fallback support. 
+func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, provider llm.Provider) (*adapter.ConversionResult, error) { + if provider == nil { + return nil, fmt.Errorf("LLM provider is required") } // Convert rules in parallel type ruleResult struct { - index int - rule *pmdRule - err error + index int + ruleID string + rule *pmdRule + err error } results := make(chan ruleResult, len(rules)) @@ -84,11 +86,12 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l go func(idx int, r schema.UserRule) { defer wg.Done() - pmdRule, err := c.convertSingleRule(ctx, r, llmClient) + pmdRule, err := c.convertSingleRule(ctx, r, provider) results <- ruleResult{ - index: idx, - rule: pmdRule, - err: err, + index: idx, + ruleID: r.ID, + rule: pmdRule, + err: err, } }(i, rule) } @@ -98,23 +101,34 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l close(results) }() - // Collect rules + // Collect rules with per-rule tracking var pmdRules []pmdRule - var errors []string + successRuleIDs := make([]string, 0) + failedRuleIDs := make([]string, 0) for result := range results { if result.err != nil { - errors = append(errors, fmt.Sprintf("Rule %d: %v", result.index+1, result.err)) + failedRuleIDs = append(failedRuleIDs, result.ruleID) continue } if result.rule != nil { pmdRules = append(pmdRules, *result.rule) + successRuleIDs = append(successRuleIDs, result.ruleID) + } else { + // Skipped = cannot be enforced by this linter + failedRuleIDs = append(failedRuleIDs, result.ruleID) } } + // Build result with tracking info + convResult := &adapter.ConversionResult{ + SuccessRules: successRuleIDs, + FailedRules: failedRuleIDs, + } + if len(pmdRules) == 0 { - return nil, fmt.Errorf("no rules converted: %v", errors) + return convResult, nil } // Build PMD ruleset @@ -136,15 +150,17 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l xmlHeader := `` + "\n" fullContent := 
[]byte(xmlHeader + string(content)) - return &adapter.LinterConfig{ + convResult.Config = &adapter.LinterConfig{ Filename: "pmd.xml", Content: fullContent, Format: "xml", - }, nil + } + + return convResult, nil } // convertSingleRule converts a single rule using LLM -func (c *Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, llmClient *llm.Client) (*pmdRule, error) { +func (c *Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, provider llm.Provider) (*pmdRule, error) { systemPrompt := `You are a PMD 7.x configuration expert. Convert natural language Java coding rules to PMD rule references. Return ONLY a JSON object with exactly these two fields (no other fields): @@ -188,8 +204,9 @@ IMPORTANT: Return ONLY the JSON object. Do NOT include description, message, or userPrompt := fmt.Sprintf("Convert this Java rule to PMD rule reference:\n\n%s", rule.Say) - // Call LLM with minimal complexity - response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) + // Call LLM + prompt := systemPrompt + "\n\n" + userPrompt + response, err := provider.Execute(ctx, prompt, llm.JSON) if err != nil { return nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/adapter/prettier/converter.go b/internal/adapter/prettier/converter.go index 2d951f9..5c21f0c 100644 --- a/internal/adapter/prettier/converter.go +++ b/internal/adapter/prettier/converter.go @@ -43,10 +43,11 @@ func (c *Converter) GetRoutingHints() []string { } } -// ConvertRules converts formatting rules to Prettier config using LLM -func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, llmClient *llm.Client) (*adapter.LinterConfig, error) { - if llmClient == nil { - return nil, fmt.Errorf("LLM client is required") +// ConvertRules converts formatting rules to Prettier config using LLM. +// Returns ConversionResult with per-rule success/failure tracking for fallback support. 
+func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, provider llm.Provider) (*adapter.ConversionResult, error) { + if provider == nil { + return nil, fmt.Errorf("LLM provider is required") } // Start with default Prettier configuration @@ -60,33 +61,56 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l "arrowParens": "always", } + // Track rule conversion results + successRuleIDs := make([]string, 0) + failedRuleIDs := make([]string, 0) + // Use LLM to infer settings from rules for _, rule := range rules { - config, err := c.convertSingleRule(ctx, rule, llmClient) + config, err := c.convertSingleRule(ctx, rule, provider) if err != nil { - continue // Skip rules that cannot be converted + failedRuleIDs = append(failedRuleIDs, rule.ID) + continue + } + + // Check if LLM returned empty config (rule cannot be enforced by Prettier) + if len(config) == 0 { + failedRuleIDs = append(failedRuleIDs, rule.ID) + continue } // Merge LLM-generated config for key, value := range config { prettierConfig[key] = value } + successRuleIDs = append(successRuleIDs, rule.ID) } - content, err := json.MarshalIndent(prettierConfig, "", " ") - if err != nil { - return nil, fmt.Errorf("failed to marshal config: %w", err) + // Build result with tracking info + convResult := &adapter.ConversionResult{ + SuccessRules: successRuleIDs, + FailedRules: failedRuleIDs, + } + + // Generate config only if at least one rule succeeded + if len(successRuleIDs) > 0 { + content, err := json.MarshalIndent(prettierConfig, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal config: %w", err) + } + + convResult.Config = &adapter.LinterConfig{ + Filename: ".prettierrc", + Content: content, + Format: "json", + } } - return &adapter.LinterConfig{ - Filename: ".prettierrc", - Content: content, - Format: "json", - }, nil + return convResult, nil } // convertSingleRule converts a single user rule to Prettier config using LLM -func (c 
*Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, llmClient *llm.Client) (map[string]interface{}, error) { +func (c *Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, provider llm.Provider) (map[string]interface{}, error) { systemPrompt := `You are a Prettier configuration expert. Convert natural language formatting rules to Prettier configuration options. Return ONLY a JSON object (no markdown fences) with Prettier options. @@ -134,7 +158,8 @@ Output: userPrompt := fmt.Sprintf("Convert this rule to Prettier configuration:\n\n%s", rule.Say) // Call LLM - response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) + prompt := systemPrompt + "\n\n" + userPrompt + response, err := provider.Execute(ctx, prompt, llm.JSON) if err != nil { return nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/adapter/pylint/converter.go b/internal/adapter/pylint/converter.go index e07ece1..9c7ac3f 100644 --- a/internal/adapter/pylint/converter.go +++ b/internal/adapter/pylint/converter.go @@ -49,15 +49,17 @@ func (c *Converter) GetRoutingHints() []string { } } -// ConvertRules converts user rules to Pylint configuration using LLM -func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, llmClient *llm.Client) (*adapter.LinterConfig, error) { - if llmClient == nil { - return nil, fmt.Errorf("LLM client is required") +// ConvertRules converts user rules to Pylint configuration using LLM. +// Returns ConversionResult with per-rule success/failure tracking for fallback support. 
+func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, provider llm.Provider) (*adapter.ConversionResult, error) { + if provider == nil { + return nil, fmt.Errorf("LLM provider is required") } // Convert rules in parallel using goroutines type ruleResult struct { index int + ruleID string symbol string options map[string]interface{} err error @@ -72,9 +74,10 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l go func(idx int, r schema.UserRule) { defer wg.Done() - symbol, options, err := c.convertSingleRule(ctx, r, llmClient) + symbol, options, err := c.convertSingleRule(ctx, r, provider) results <- ruleResult{ index: idx, + ruleID: r.ID, symbol: symbol, options: options, err: err, @@ -88,21 +91,22 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l close(results) }() - // Collect results + // Collect results with per-rule tracking enabledRules := make([]string, 0) options := make(map[string]map[string]interface{}) - var errors []string - skippedCount := 0 + successRuleIDs := make([]string, 0) + failedRuleIDs := make([]string, 0) for result := range results { if result.err != nil { - errors = append(errors, fmt.Sprintf("Rule %d: %v", result.index+1, result.err)) - fmt.Fprintf(os.Stderr, "⚠️ Pylint rule %d conversion error: %v\n", result.index+1, result.err) + failedRuleIDs = append(failedRuleIDs, result.ruleID) + fmt.Fprintf(os.Stderr, "⚠️ Pylint rule %s conversion error: %v\n", result.ruleID, result.err) continue } if result.symbol != "" { enabledRules = append(enabledRules, result.symbol) + successRuleIDs = append(successRuleIDs, result.ruleID) if len(result.options) > 0 { // Group options by section for key, value := range result.options { @@ -113,33 +117,35 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l options[section][key] = value } } - fmt.Fprintf(os.Stderr, "✓ Pylint rule %d → %s\n", result.index+1, result.symbol) + 
fmt.Fprintf(os.Stderr, "✓ Pylint rule %s → %s\n", result.ruleID, result.symbol) } else { - skippedCount++ - fmt.Fprintf(os.Stderr, "⊘ Pylint rule %d skipped (cannot be enforced by Pylint)\n", result.index+1) + // Skipped = cannot be enforced by this linter, fallback to llm-validator + failedRuleIDs = append(failedRuleIDs, result.ruleID) + fmt.Fprintf(os.Stderr, "⊘ Pylint rule %s skipped (cannot be enforced by Pylint)\n", result.ruleID) } } - if skippedCount > 0 { - fmt.Fprintf(os.Stderr, "ℹ️ %d rule(s) skipped for Pylint (will use llm-validator)\n", skippedCount) + // Build result with tracking info + convResult := &adapter.ConversionResult{ + SuccessRules: successRuleIDs, + FailedRules: failedRuleIDs, } - if len(enabledRules) == 0 { - return nil, fmt.Errorf("no rules converted successfully: %v", errors) + // Generate config only if at least one rule succeeded + if len(enabledRules) > 0 { + content := c.generatePylintRC(enabledRules, options) + convResult.Config = &adapter.LinterConfig{ + Filename: ".pylintrc", + Content: []byte(content), + Format: "ini", + } } - // Generate .pylintrc content (INI format) - content := c.generatePylintRC(enabledRules, options) - - return &adapter.LinterConfig{ - Filename: ".pylintrc", - Content: []byte(content), - Format: "ini", - }, nil + return convResult, nil } // convertSingleRule converts a single user rule to Pylint rule using LLM -func (c *Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, llmClient *llm.Client) (string, map[string]interface{}, error) { +func (c *Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, provider llm.Provider) (string, map[string]interface{}, error) { systemPrompt := `You are a Pylint configuration expert. Convert natural language Python coding rules to Pylint rule configurations. 
Return ONLY a JSON object (no markdown fences) with this structure: @@ -210,7 +216,8 @@ Output: } // Call LLM - response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) + prompt := systemPrompt + "\n\n" + userPrompt + response, err := provider.Execute(ctx, prompt, llm.JSON) if err != nil { return "", nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/adapter/pylint/executor.go b/internal/adapter/pylint/executor.go index 6b4ae3a..1d9c368 100644 --- a/internal/adapter/pylint/executor.go +++ b/internal/adapter/pylint/executor.go @@ -37,9 +37,7 @@ func (a *Adapter) execute(ctx context.Context, config []byte, files []string) (* func (a *Adapter) getExecutionArgs(configPath string, files []string) []string { args := []string{ "--output-format=json", - "--disable=all", // Disable all checks first - "--enable=invalid-name", // Enable only naming checks (C0103) - "--rcfile=" + configPath, + "--rcfile=" + configPath, // Use .pylintrc settings as-is } args = append(args, files...) diff --git a/internal/adapter/tsc/converter.go b/internal/adapter/tsc/converter.go index 7dc4a9f..f4f2f4c 100644 --- a/internal/adapter/tsc/converter.go +++ b/internal/adapter/tsc/converter.go @@ -43,10 +43,11 @@ func (c *Converter) GetRoutingHints() []string { } } -// ConvertRules converts type-checking rules to tsconfig.json using LLM -func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, llmClient *llm.Client) (*adapter.LinterConfig, error) { - if llmClient == nil { - return nil, fmt.Errorf("LLM client is required") +// ConvertRules converts type-checking rules to tsconfig.json using LLM. +// Returns ConversionResult with per-rule success/failure tracking for fallback support. 
+func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, provider llm.Provider) (*adapter.ConversionResult, error) { + if provider == nil { + return nil, fmt.Errorf("LLM provider is required") } // Start with strict TypeScript configuration @@ -71,33 +72,56 @@ func (c *Converter) ConvertRules(ctx context.Context, rules []schema.UserRule, l compilerOpts := tsConfig["compilerOptions"].(map[string]interface{}) + // Track rule conversion results + successRuleIDs := make([]string, 0) + failedRuleIDs := make([]string, 0) + // Use LLM to infer settings from rules for _, rule := range rules { - config, err := c.convertSingleRule(ctx, rule, llmClient) + config, err := c.convertSingleRule(ctx, rule, provider) if err != nil { - continue // Skip rules that cannot be converted + failedRuleIDs = append(failedRuleIDs, rule.ID) + continue + } + + // Check if LLM returned empty config (rule cannot be enforced by TSC) + if len(config) == 0 { + failedRuleIDs = append(failedRuleIDs, rule.ID) + continue } // Merge LLM-generated compiler options for key, value := range config { compilerOpts[key] = value } + successRuleIDs = append(successRuleIDs, rule.ID) } - content, err := json.MarshalIndent(tsConfig, "", " ") - if err != nil { - return nil, fmt.Errorf("failed to marshal config: %w", err) + // Build result with tracking info + convResult := &adapter.ConversionResult{ + SuccessRules: successRuleIDs, + FailedRules: failedRuleIDs, + } + + // Generate config only if at least one rule succeeded + if len(successRuleIDs) > 0 { + content, err := json.MarshalIndent(tsConfig, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal config: %w", err) + } + + convResult.Config = &adapter.LinterConfig{ + Filename: "tsconfig.json", + Content: content, + Format: "json", + } } - return &adapter.LinterConfig{ - Filename: "tsconfig.json", - Content: content, - Format: "json", - }, nil + return convResult, nil } // convertSingleRule converts a single user rule to 
TypeScript compiler option using LLM -func (c *Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, llmClient *llm.Client) (map[string]interface{}, error) { +func (c *Converter) convertSingleRule(ctx context.Context, rule schema.UserRule, provider llm.Provider) (map[string]interface{}, error) { systemPrompt := `You are a TypeScript compiler configuration expert. Convert natural language type-checking rules to tsconfig.json compiler options. Return ONLY a JSON object (no markdown fences) with TypeScript compiler options. @@ -148,7 +172,8 @@ Output: userPrompt := fmt.Sprintf("Convert this rule to TypeScript compiler configuration:\n\n%s", rule.Say) // Call LLM - response, err := llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityMinimal).Execute(ctx) + prompt := systemPrompt + "\n\n" + userPrompt + response, err := provider.Execute(ctx, prompt, llm.JSON) if err != nil { return nil, fmt.Errorf("LLM call failed: %w", err) } diff --git a/internal/bootstrap/providers.go b/internal/bootstrap/providers.go new file mode 100644 index 0000000..7287aa2 --- /dev/null +++ b/internal/bootstrap/providers.go @@ -0,0 +1,8 @@ +package bootstrap + +import ( + // Import LLM providers for registration side-effects. 
+ _ "github.com/DevSymphony/sym-cli/internal/llm/claudecode" + _ "github.com/DevSymphony/sym-cli/internal/llm/geminicli" + _ "github.com/DevSymphony/sym-cli/internal/llm/openaiapi" +) diff --git a/internal/cmd/api_key.go b/internal/cmd/api_key.go deleted file mode 100644 index ec5c33e..0000000 --- a/internal/cmd/api_key.go +++ /dev/null @@ -1,212 +0,0 @@ -package cmd - -import ( - "bufio" - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/DevSymphony/sym-cli/internal/envutil" - "github.com/manifoldco/promptui" -) - -// promptAPIKeySetup prompts user to setup API key (without checking if it exists) -func promptAPIKeySetup() { - promptAPIKeyConfiguration(false) -} - -// promptAPIKeyConfiguration handles API key configuration with optional existence check -func promptAPIKeyConfiguration(checkExisting bool) { - envPath := filepath.Join(".sym", ".env") - - if checkExisting { - // 1. Check environment variable or .env file - if envutil.GetAPIKey("OPENAI_API_KEY") != "" { - fmt.Println("\n✓ OpenAI API key detected from environment or .sym/.env") - return - } - - // 2. Check .sym/.env file - if hasAPIKeyInEnvFile(envPath) { - fmt.Println("\n✓ OpenAI API key found in .sym/.env") - return - } - - // Neither found - show warning - fmt.Println("\n⚠ OpenAI API key not found") - fmt.Println(" (Required for convert, validate commands and MCP auto-conversion)") - fmt.Println() - } - - // Create selection prompt - items := []string{ - "Enter API key", - "Skip (set manually later)", - } - - templates := &promptui.SelectTemplates{ - Label: "{{ . }}?", - Active: "▸ {{ . | cyan }}", - Inactive: " {{ . }}", - Selected: "✓ {{ . 
| green }}", - } - - selectPrompt := promptui.Select{ - Label: "Would you like to configure it now", - Items: items, - Templates: templates, - Size: 2, - } - - index, _, err := selectPrompt.Run() - if err != nil { - fmt.Println("\nSkipped API key configuration") - return - } - - switch index { - case 0: // Enter API key - apiKey, err := promptForAPIKey() - if err != nil { - fmt.Printf("\n❌ Failed to read API key: %v\n", err) - return - } - - // Validate API key format - if err := validateAPIKey(apiKey); err != nil { - fmt.Printf("\n⚠ Warning: %v\n", err) - fmt.Println(" API key was saved anyway. Make sure it's correct.") - } - - // Save to .sym/.env - if err := envutil.SaveKeyToEnvFile(envPath, "OPENAI_API_KEY", apiKey); err != nil { - fmt.Printf("\n❌ Failed to save API key: %v\n", err) - return - } - - fmt.Println("\n✓ API key saved to .sym/.env") - - // Add to .gitignore - if err := ensureGitignore(".sym/.env"); err != nil { - fmt.Printf("⚠ Warning: Failed to update .gitignore: %v\n", err) - fmt.Println(" Please manually add '.sym/.env' to .gitignore") - } else { - fmt.Println("✓ Added .sym/.env to .gitignore") - } - - case 1: // Skip - fmt.Println("\nSkipped API key configuration") - fmt.Println("\n💡 Tip: You can set OPENAI_API_KEY in:") - fmt.Println(" - .sym/.env file") - fmt.Println(" - System environment variable") - } -} - -// promptForAPIKey prompts user to enter API key -func promptForAPIKey() (string, error) { - fmt.Print("Enter your OpenAI API key: ") - - // Use bufio reader for better paste support - reader := bufio.NewReader(os.Stdin) - input, err := reader.ReadString('\n') - if err != nil { - return "", fmt.Errorf("failed to read API key: %w", err) - } - - // Clean the input: remove all whitespace, control characters, and non-printable characters - apiKey := cleanAPIKey(input) - - if len(apiKey) == 0 { - return "", fmt.Errorf("API key cannot be empty") - } - - return apiKey, nil -} - -// cleanAPIKey removes whitespace, control characters, and 
non-printable characters from API key -func cleanAPIKey(input string) string { - var result strings.Builder - for _, r := range input { - // Only keep printable ASCII characters (excluding space) - if r >= 33 && r <= 126 { - result.WriteRune(r) - } - } - return result.String() -} - -// validateAPIKey performs basic validation on API key format -func validateAPIKey(key string) error { - if !strings.HasPrefix(key, "sk-") { - return fmt.Errorf("API key should start with 'sk-'") - } - if len(key) < 20 { - return fmt.Errorf("API key seems too short") - } - return nil -} - -// hasAPIKeyInEnvFile checks if OPENAI_API_KEY exists in .env file -func hasAPIKeyInEnvFile(envPath string) bool { - file, err := os.Open(envPath) - if err != nil { - return false - } - defer func() { _ = file.Close() }() - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if strings.HasPrefix(line, "OPENAI_API_KEY=") { - parts := strings.SplitN(line, "=", 2) - if len(parts) == 2 && strings.TrimSpace(parts[1]) != "" { - return true - } - } - } - - return false -} - -// ensureGitignore ensures that the given path is in .gitignore -func ensureGitignore(path string) error { - gitignorePath := ".gitignore" - - // Read existing .gitignore - var lines []string - existingFile, err := os.Open(gitignorePath) - if err == nil { - scanner := bufio.NewScanner(existingFile) - for scanner.Scan() { - line := scanner.Text() - lines = append(lines, line) - // Check if already exists - if strings.TrimSpace(line) == path { - _ = existingFile.Close() - return nil // Already in .gitignore - } - } - _ = existingFile.Close() - } - - // Add to .gitignore - lines = append(lines, "", "# Symphony API key configuration", path) - content := strings.Join(lines, "\n") + "\n" - - if err := os.WriteFile(gitignorePath, []byte(content), 0644); err != nil { - return fmt.Errorf("failed to update .gitignore: %w", err) - } - - return nil -} - -// getAPIKey retrieves OpenAI API key from 
environment or .env file -// Returns error if not found -func getAPIKey() (string, error) { - key := envutil.GetAPIKey("OPENAI_API_KEY") - if key == "" { - return "", fmt.Errorf("OPENAI_API_KEY not found in environment or .sym/.env") - } - return key, nil -} diff --git a/internal/cmd/convert.go b/internal/cmd/convert.go index 1b458ad..d3f2f59 100644 --- a/internal/cmd/convert.go +++ b/internal/cmd/convert.go @@ -5,24 +5,22 @@ import ( "encoding/json" "fmt" "os" - "path/filepath" "strings" "time" "github.com/DevSymphony/sym-cli/internal/adapter/registry" + "github.com/DevSymphony/sym-cli/internal/config" "github.com/DevSymphony/sym-cli/internal/converter" "github.com/DevSymphony/sym-cli/internal/llm" + "github.com/DevSymphony/sym-cli/internal/ui" "github.com/DevSymphony/sym-cli/pkg/schema" "github.com/spf13/cobra" ) var ( - convertInputFile string - convertOutputFile string - convertTargets []string - convertOutputDir string - convertConfidenceThreshold float64 - convertTimeout int + convertInputFile string + convertTargets []string + convertOutputDir string ) var convertCmd = &cobra.Command{ @@ -41,25 +39,17 @@ map them to appropriate linter rules.`, sym convert -i user-policy.json --targets eslint # Convert for Java with specific model - sym convert -i user-policy.json --targets checkstyle,pmd --openai-model gpt-5-mini - # Convert for Java with specific model - sym convert -i user-policy.json --targets checkstyle,pmd --openai-model gpt-5-mini + sym convert -i user-policy.json --targets checkstyle,pmd # Use custom output directory - sym convert -i user-policy.json --targets all --output-dir ./custom-dir - - # Legacy mode (internal policy only) - sym convert -i user-policy.json -o code-policy.json`, + sym convert -i user-policy.json --targets all --output-dir ./custom-dir`, RunE: runConvert, } func init() { convertCmd.Flags().StringVarP(&convertInputFile, "input", "i", "", "input user policy file (default: from .sym/.env POLICY_PATH)") - 
convertCmd.Flags().StringVarP(&convertOutputFile, "output", "o", "", "output code policy file (legacy mode)") convertCmd.Flags().StringSliceVar(&convertTargets, "targets", []string{}, buildTargetsDescription()) convertCmd.Flags().StringVar(&convertOutputDir, "output-dir", "", "output directory for linter configs (default: same as input file directory)") - convertCmd.Flags().Float64Var(&convertConfidenceThreshold, "confidence-threshold", 0.7, "minimum confidence for LLM inference (0.0-1.0)") - convertCmd.Flags().IntVar(&convertTimeout, "timeout", 30, "timeout for API calls in seconds") } // buildTargetsDescription dynamically builds the --targets flag description @@ -74,13 +64,14 @@ func buildTargetsDescription() string { func runConvert(cmd *cobra.Command, args []string) error { // Determine input file path if convertInputFile == "" { - // Try to load from .env - policyPath := loadPolicyPathFromEnv() + // Load from config.json + projectCfg, _ := config.LoadProjectConfig() + policyPath := projectCfg.PolicyPath if policyPath == "" { policyPath = ".sym/user-policy.json" // fallback default } convertInputFile = policyPath - fmt.Printf("Using policy path from .env: %s\n", convertInputFile) + fmt.Printf("Using policy path from config: %s\n", convertInputFile) } // Read input file @@ -100,32 +91,6 @@ func runConvert(cmd *cobra.Command, args []string) error { return runNewConverter(&userPolicy) } -// loadPolicyPathFromEnv reads POLICY_PATH from .sym/.env -func loadPolicyPathFromEnv() string { - envPath := filepath.Join(".sym", ".env") - data, err := os.ReadFile(envPath) - if err != nil { - return "" - } - - lines := strings.Split(string(data), "\n") - prefix := "POLICY_PATH=" - - for _, line := range lines { - line = strings.TrimSpace(line) - // Skip comments and empty lines - if len(line) == 0 || line[0] == '#' { - continue - } - // Check if line starts with POLICY_PATH= - if strings.HasPrefix(line, prefix) { - return strings.TrimSpace(line[len(prefix):]) - } - } - - 
return "" -} - func runNewConverter(userPolicy *schema.UserPolicy) error { // Determine output directory if convertOutputDir == "" { @@ -133,28 +98,23 @@ func runNewConverter(userPolicy *schema.UserPolicy) error { convertOutputDir = ".sym" } - timeout := time.Duration(convertTimeout) * time.Second - llmClient := llm.NewClient( - llm.WithTimeout(timeout), - ) - - // Ensure at least one backend is available (MCP/CLI/API) - availabilityCtx, cancelAvailability := context.WithTimeout(context.Background(), timeout) - defer cancelAvailability() - - if err := llmClient.CheckAvailability(availabilityCtx); err != nil { - return fmt.Errorf("no available LLM backend for convert: %w\nTip: run 'sym init --setup-llm' or configure LLM_BACKEND / LLM_CLI / OPENAI_API_KEY in .sym/.env", err) + // Create LLM provider + cfg := llm.LoadConfig() + llmProvider, err := llm.New(cfg) + if err != nil { + return fmt.Errorf("no available LLM backend for convert: %w\nTip: configure provider in .sym/config.json", err) } + defer llmProvider.Close() // Create new converter - conv := converter.NewConverter(llmClient, convertOutputDir) + conv := converter.NewConverter(llmProvider, convertOutputDir) - // Setup context with generous timeout for parallel processing - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(convertTimeout*10)*time.Second) + // Setup context with generous timeout for parallel processing (10 minutes to match validator) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() - fmt.Printf("\n🚀 Converting with language-based routing and parallel LLM inference\n") - fmt.Printf("📂 Output: %s\n\n", convertOutputDir) + ui.PrintTitle("Convert", "Language-based routing with parallel LLM inference") + fmt.Printf("Output: %s\n\n", convertOutputDir) // Convert result, err := conv.Convert(ctx, userPolicy) @@ -163,23 +123,26 @@ func runNewConverter(userPolicy *schema.UserPolicy) error { } // Print results - fmt.Printf("\n✅ Conversion 
completed successfully!\n") - fmt.Printf("📦 Generated %d configuration file(s):\n", len(result.GeneratedFiles)) + fmt.Println() + ui.PrintOK("Conversion completed successfully") + fmt.Printf("Generated %d configuration file(s):\n", len(result.GeneratedFiles)) for _, file := range result.GeneratedFiles { - fmt.Printf(" ✓ %s\n", file) + fmt.Printf(" - %s\n", file) } if len(result.Errors) > 0 { - fmt.Printf("\n⚠️ Errors (%d):\n", len(result.Errors)) + fmt.Println() + ui.PrintWarn(fmt.Sprintf("Errors (%d):", len(result.Errors))) for linter, err := range result.Errors { - fmt.Printf(" ✗ %s: %v\n", linter, err) + fmt.Printf(" - %s: %v\n", linter, err) } } if len(result.Warnings) > 0 { - fmt.Printf("\n⚠️ Warnings (%d):\n", len(result.Warnings)) + fmt.Println() + ui.PrintWarn(fmt.Sprintf("Warnings (%d):", len(result.Warnings))) for _, warning := range result.Warnings { - fmt.Printf(" • %s\n", warning) + fmt.Printf(" - %s\n", warning) } } diff --git a/internal/cmd/dashboard.go b/internal/cmd/dashboard.go index 15e8501..b61fd8b 100644 --- a/internal/cmd/dashboard.go +++ b/internal/cmd/dashboard.go @@ -6,6 +6,7 @@ import ( "github.com/DevSymphony/sym-cli/internal/roles" "github.com/DevSymphony/sym-cli/internal/server" + "github.com/DevSymphony/sym-cli/internal/ui" "github.com/spf13/cobra" ) @@ -33,7 +34,7 @@ func runDashboard(cmd *cobra.Command, args []string) { // Check if roles.json exists exists, err := roles.RolesExists() if err != nil || !exists { - fmt.Println("❌ roles.json not found") + ui.PrintError("roles.json not found") fmt.Println("Run 'sym init' to create it") os.Exit(1) } @@ -41,12 +42,12 @@ func runDashboard(cmd *cobra.Command, args []string) { // Start server srv, err := server.NewServer(dashboardPort) if err != nil { - fmt.Printf("❌ Failed to create server: %v\n", err) + ui.PrintError(fmt.Sprintf("Failed to create server: %v", err)) os.Exit(1) } if err := srv.Start(); err != nil { - fmt.Printf("❌ Failed to start server: %v\n", err) + 
ui.PrintError(fmt.Sprintf("Failed to start server: %v", err)) os.Exit(1) } } diff --git a/internal/cmd/init.go b/internal/cmd/init.go index db11c01..3cc66d1 100644 --- a/internal/cmd/init.go +++ b/internal/cmd/init.go @@ -6,9 +6,10 @@ import ( "path/filepath" "github.com/DevSymphony/sym-cli/internal/adapter/registry" - "github.com/DevSymphony/sym-cli/internal/envutil" + "github.com/DevSymphony/sym-cli/internal/config" "github.com/DevSymphony/sym-cli/internal/policy" "github.com/DevSymphony/sym-cli/internal/roles" + "github.com/DevSymphony/sym-cli/internal/ui" "github.com/DevSymphony/sym-cli/pkg/schema" "github.com/spf13/cobra" @@ -31,8 +32,6 @@ var ( initForce bool skipMCPRegister bool registerMCPOnly bool - skipAPIKey bool - setupAPIKeyOnly bool skipLLMSetup bool setupLLMOnly bool ) @@ -41,8 +40,6 @@ func init() { initCmd.Flags().BoolVarP(&initForce, "force", "f", false, "Overwrite existing roles.json") initCmd.Flags().BoolVar(&skipMCPRegister, "skip-mcp", false, "Skip MCP server registration prompt") initCmd.Flags().BoolVar(®isterMCPOnly, "register-mcp", false, "Register MCP server only (skip roles/policy init)") - initCmd.Flags().BoolVar(&skipAPIKey, "skip-api-key", false, "Skip OpenAI API key configuration prompt (deprecated, use --skip-llm)") - initCmd.Flags().BoolVar(&setupAPIKeyOnly, "setup-api-key", false, "Setup OpenAI API key only (deprecated, use --setup-llm)") initCmd.Flags().BoolVar(&skipLLMSetup, "skip-llm", false, "Skip LLM backend configuration prompt") initCmd.Flags().BoolVar(&setupLLMOnly, "setup-llm", false, "Setup LLM backend only (skip roles/policy init)") } @@ -50,21 +47,14 @@ func init() { func runInit(cmd *cobra.Command, args []string) { // MCP registration only mode if registerMCPOnly { - fmt.Println("🔧 Registering Symphony MCP server...") + ui.PrintTitle("MCP", "Registering Symphony MCP server") promptMCPRegistration() return } - // API key setup only mode (deprecated) - if setupAPIKeyOnly { - fmt.Println("🔑 Setting up OpenAI API key...") 
- promptAPIKeySetup() - return - } - // LLM setup only mode if setupLLMOnly { - fmt.Println("🤖 Setting up LLM backend...") + ui.PrintTitle("LLM", "Setting up LLM backend") promptLLMBackendSetup() return } @@ -72,12 +62,12 @@ func runInit(cmd *cobra.Command, args []string) { // Check if roles.json already exists exists, err := roles.RolesExists() if err != nil { - fmt.Printf("❌ Failed to check roles.json: %v\n", err) + ui.PrintError(fmt.Sprintf("Failed to check roles.json: %v", err)) os.Exit(1) } if exists && !initForce { - fmt.Println("⚠ roles.json already exists") + ui.PrintWarn("roles.json already exists") fmt.Println("Use --force flag to overwrite") os.Exit(1) } @@ -85,7 +75,7 @@ func runInit(cmd *cobra.Command, args []string) { // If force flag is set, remove existing code-policy.json if initForce { if err := removeExistingCodePolicy(); err != nil { - fmt.Printf("⚠ Warning: Failed to remove existing code-policy.json: %v\n", err) + ui.PrintWarn(fmt.Sprintf("Failed to remove existing code-policy.json: %v", err)) } } @@ -97,36 +87,34 @@ func runInit(cmd *cobra.Command, args []string) { } if err := roles.SaveRoles(newRoles); err != nil { - fmt.Printf("❌ Failed to create roles.json: %v\n", err) + ui.PrintError(fmt.Sprintf("Failed to create roles.json: %v", err)) os.Exit(1) } rolesPath, _ := roles.GetRolesPath() - fmt.Println("✓ roles.json created successfully!") - fmt.Printf(" Location: %s\n", rolesPath) + ui.PrintOK("roles.json created") + fmt.Println(ui.Indent(fmt.Sprintf("Location: %s", rolesPath))) // Create default policy file with RBAC roles - fmt.Println("\nCreating default policy file...") if err := createDefaultPolicy(); err != nil { - fmt.Printf("⚠ Warning: Failed to create policy file: %v\n", err) - fmt.Println("You can manually create it later using the dashboard") + ui.PrintWarn(fmt.Sprintf("Failed to create policy file: %v", err)) + fmt.Println(ui.Indent("You can manually create it later using the dashboard")) } else { - fmt.Println("✓ user-policy.json 
created with default RBAC roles") + ui.PrintOK("user-policy.json created with default RBAC roles") } - // Create .sym/.env with default POLICY_PATH - fmt.Println("\nSetting up environment configuration...") - if err := initializeEnvFile(); err != nil { - fmt.Printf("⚠ Warning: Failed to create .sym/.env: %v\n", err) + // Create .sym/config.json with default settings + if err := initializeConfigFile(); err != nil { + ui.PrintWarn(fmt.Sprintf("Failed to create config.json: %v", err)) } else { - fmt.Println("✓ .sym/.env created with default policy path") + ui.PrintOK("config.json created") } // Set default role to admin during initialization if err := roles.SetCurrentRole("admin"); err != nil { - fmt.Printf("⚠ Warning: Failed to save role selection: %v\n", err) + ui.PrintWarn(fmt.Sprintf("Failed to save role selection: %v", err)) } else { - fmt.Println("✓ Your role has been set to: admin (default for initialization)") + ui.PrintOK("Your role has been set to: admin") } // MCP registration prompt @@ -135,23 +123,17 @@ func runInit(cmd *cobra.Command, args []string) { } // LLM backend configuration prompt - if !skipLLMSetup && !skipAPIKey { + if !skipLLMSetup { promptLLMBackendSetup() } - // Show dashboard guide after all initialization is complete - fmt.Println("\n🎯 What's Next: Use Symphony Dashboard") - fmt.Println() - fmt.Println("Start the web dashboard:") - fmt.Println(" sym dashboard") + // Show completion message fmt.Println() - fmt.Println("Dashboard features:") - fmt.Println(" 📋 Manage roles - Configure permissions for each role") - fmt.Println(" 📝 Edit policies - Create and modify coding conventions") - fmt.Println(" 🎭 Change role - Select a different role anytime") - fmt.Println(" ✅ Test validation - Check rules against your code in real-time") + ui.PrintDone("Initialization complete") fmt.Println() - fmt.Println("After setup, commit and push .sym/roles.json and .sym/user-policy.json to share with your team.") + fmt.Println("Next steps:") + 
fmt.Println(ui.Indent("Run 'sym dashboard' to manage roles and policies")) + fmt.Println(ui.Indent("Commit .sym/ folder to share with your team")) } // createDefaultPolicy creates a default policy file with RBAC roles @@ -205,26 +187,19 @@ func createDefaultPolicy() error { return policy.SavePolicy(defaultPolicy, defaultPolicyPath) } -// initializeEnvFile creates .sym/.env with default configuration -func initializeEnvFile() error { - envPath := filepath.Join(".sym", ".env") - defaultPolicyPath := ".sym/user-policy.json" +// initializeConfigFile creates .sym/config.json with default settings +func initializeConfigFile() error { + // Check if config.json already exists + if config.ProjectConfigExists() { + return nil + } - // Check if .env already exists - if _, err := os.Stat(envPath); err == nil { - // File exists, check if POLICY_PATH is already set - existingPath := envutil.LoadKeyFromEnvFile(envPath, "POLICY_PATH") - if existingPath != "" { - // POLICY_PATH already set, nothing to do - return nil - } - // POLICY_PATH not set, add it - return envutil.SaveKeyToEnvFile(envPath, "POLICY_PATH", defaultPolicyPath) + // Create default project config + defaultConfig := &config.ProjectConfig{ + PolicyPath: ".sym/user-policy.json", } - // .env doesn't exist, create it with default settings - content := fmt.Sprintf("# Symphony local configuration\nPOLICY_PATH=%s\nCURRENT_ROLE=admin\n", defaultPolicyPath) - return os.WriteFile(envPath, []byte(content), 0644) + return config.SaveProjectConfig(defaultConfig) } // removeExistingCodePolicy removes generated linter config files when --force flag is used @@ -240,9 +215,9 @@ func removeExistingCodePolicy() error { filePath := filepath.Join(symDir, filename) if _, err := os.Stat(filePath); err == nil { if err := os.Remove(filePath); err != nil { - fmt.Printf("⚠ Warning: Failed to remove %s: %v\n", filePath, err) + ui.PrintWarn(fmt.Sprintf("Failed to remove %s: %v", filePath, err)) } else { - fmt.Printf("✓ Removed existing %s\n", 
filePath) + fmt.Println(ui.Indent(fmt.Sprintf("Removed existing %s", filePath))) } } } @@ -254,7 +229,7 @@ func removeExistingCodePolicy() error { if err := os.Remove(legacyPath); err != nil { return fmt.Errorf("failed to remove %s: %w", legacyPath, err) } - fmt.Printf("✓ Removed existing %s\n", legacyPath) + fmt.Println(ui.Indent(fmt.Sprintf("Removed existing %s", legacyPath))) } return nil diff --git a/internal/cmd/llm.go b/internal/cmd/llm.go index 6b8e8e8..f8ce60c 100644 --- a/internal/cmd/llm.go +++ b/internal/cmd/llm.go @@ -1,498 +1,354 @@ package cmd import ( + "bufio" "context" "fmt" "os" "strings" "time" + "github.com/AlecAivazis/survey/v2" + "github.com/DevSymphony/sym-cli/internal/config" + "github.com/DevSymphony/sym-cli/internal/envutil" "github.com/DevSymphony/sym-cli/internal/llm" - "github.com/DevSymphony/sym-cli/internal/llm/engine" - "github.com/manifoldco/promptui" + "github.com/DevSymphony/sym-cli/internal/ui" "github.com/spf13/cobra" ) var llmCmd = &cobra.Command{ Use: "llm", - Short: "Manage LLM engine configuration", - Long: `Configure and manage LLM engines for Symphony. + Short: "Manage LLM provider configuration", + Long: `Configure and manage LLM providers for Symphony. 
-Symphony supports multiple LLM engines: - - MCP Sampling: Uses the host LLM when running as MCP server - - CLI: Uses local CLI tools (claude, gemini) - - API: Uses OpenAI API directly +Symphony supports multiple LLM providers: + - claudecode: Claude Code CLI (requires 'claude' in PATH) + - geminicli: Gemini CLI (requires 'gemini' in PATH) + - openaiapi: OpenAI API (requires OPENAI_API_KEY) -The default mode is 'auto' which tries engines in this order: -MCP Sampling → CLI → API`, -} - -var llmSetupCmd = &cobra.Command{ - Use: "setup", - Short: "Interactive LLM engine setup", - Long: `Interactively configure which LLM engine to use.`, - Run: runLLMSetup, +Configuration is stored in: + - .sym/config.json: Provider and model settings (safe to commit) + - .sym/.env: API keys (gitignored)`, } var llmStatusCmd = &cobra.Command{ Use: "status", - Short: "Show current LLM engine status", - Long: `Display the current LLM engine configuration and availability.`, + Short: "Show current LLM provider status", + Long: `Display the current LLM provider configuration and availability.`, Run: runLLMStatus, } var llmTestCmd = &cobra.Command{ Use: "test", - Short: "Test LLM engine connection", - Long: `Send a test request to verify LLM engine is working.`, + Short: "Test LLM provider connection", + Long: `Send a test request to verify LLM provider is working.`, Run: runLLMTest, } +var llmSetupCmd = &cobra.Command{ + Use: "setup", + Short: "Show LLM setup instructions", + Long: `Display instructions for configuring LLM providers.`, + Run: runLLMSetup, +} + func init() { rootCmd.AddCommand(llmCmd) - llmCmd.AddCommand(llmSetupCmd) llmCmd.AddCommand(llmStatusCmd) llmCmd.AddCommand(llmTestCmd) + llmCmd.AddCommand(llmSetupCmd) } -func runLLMSetup(_ *cobra.Command, _ []string) { - fmt.Println("🤖 LLM Engine Configuration") +func runLLMStatus(_ *cobra.Command, _ []string) { + ui.PrintTitle("LLM", "Provider Status") fmt.Println() - // Load current config - cfg := llm.LoadLLMConfig() + // Load 
config + cfg := llm.LoadConfig() - // Show current settings - fmt.Println("Current settings:") - fmt.Printf(" Engine mode: %s\n", cfg.Backend) - if cfg.CLI != "" { - fmt.Printf(" CLI: %s\n", cfg.CLI) + fmt.Println("Configuration:") + if cfg.Provider != "" { + fmt.Printf(" Provider: %s\n", cfg.Provider) + } else { + fmt.Println(" Provider: (not configured)") } if cfg.Model != "" { fmt.Printf(" Model: %s\n", cfg.Model) } - if cfg.HasAPIKey() { - fmt.Println(" API Key: configured") - } else { - fmt.Println(" API Key: not set") - } fmt.Println() - // Show menu - items := []string{ - "Configure CLI tool", - "Set OpenAI API key", - "Change engine mode", - "Test current configuration", - "Reset to defaults", - "Exit", - } - - templates := &promptui.SelectTemplates{ - Label: "{{ . }}?", - Active: "▸ {{ . | cyan }}", - Inactive: " {{ . }}", - Selected: "✓ {{ . | green }}", - } - - selectPrompt := promptui.Select{ - Label: "What would you like to configure", - Items: items, - Templates: templates, - Size: 6, + // Show available providers + fmt.Println("Available providers:") + providers := llm.ListProviders() + for _, p := range providers { + status := "not available" + if p.Available { + status = "available" + if p.Path != "" { + status = fmt.Sprintf("available (%s)", p.Path) + } + } + fmt.Printf(" %s: %s\n", p.DisplayName, status) } + fmt.Println() - index, _, err := selectPrompt.Run() + // Try to create provider + provider, err := llm.New(cfg) if err != nil { - fmt.Println("\nSetup cancelled") - return + ui.PrintWarn(fmt.Sprintf("Configuration error: %v", err)) + } else { + ui.PrintOK(fmt.Sprintf("Active provider: %s", provider.Name())) } - switch index { - case 0: - configureCLI(cfg) - case 1: - promptAPIKeySetup() - case 2: - configureEngineMode(cfg) - case 3: - runLLMTest(nil, nil) - case 4: - resetLLMConfig() - case 5: - fmt.Println("\nExiting setup") - } + fmt.Println() + fmt.Println("Run 'sym llm setup' for configuration instructions") + fmt.Println("Run 'sym llm 
test' to verify connection") } -func configureCLI(cfg *llm.LLMConfig) { - fmt.Println("\n🔧 CLI Tool Configuration") +func runLLMTest(_ *cobra.Command, _ []string) { + ui.PrintTitle("LLM", "Testing Provider Connection") fmt.Println() - // Detect available CLIs - clis := engine.DetectAvailableCLIs() - - // Build selection items - var items []string - var availableCLIs []engine.CLIInfo - - for _, cli := range clis { - status := "✗ not found" - if cli.Available { - status = "✓ available" - if cli.Version != "" { - status = fmt.Sprintf("✓ %s", cli.Version) - } - } - items = append(items, fmt.Sprintf("%s (%s)", cli.Name, status)) - availableCLIs = append(availableCLIs, cli) - } - - items = append(items, "Skip CLI configuration") - - templates := &promptui.SelectTemplates{ - Label: "{{ . }}?", - Active: "▸ {{ . | cyan }}", - Inactive: " {{ . }}", - Selected: "✓ {{ . | green }}", - } - - selectPrompt := promptui.Select{ - Label: "Select CLI tool to use", - Items: items, - Templates: templates, - Size: len(items), - } + // Load config + cfg := llm.LoadConfig() - index, _, err := selectPrompt.Run() - if err != nil || index >= len(availableCLIs) { - fmt.Println("\nCLI configuration skipped") + // Create provider + provider, err := llm.New(cfg) + if err != nil { + ui.PrintError(fmt.Sprintf("Failed to create provider: %v", err)) + fmt.Println() + fmt.Println("Please configure a provider:") + fmt.Println(" sym llm setup") return } - selectedCLI := availableCLIs[index] - - if !selectedCLI.Available { - fmt.Printf("\n⚠️ %s is not installed or not in PATH\n", selectedCLI.Name) - fmt.Println("Please install it first and try again") - return - } + fmt.Printf("Testing provider: %s\n\n", provider.Name()) - // Update config - cfg.CLI = string(selectedCLI.Provider) + // Create test request + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - // Get provider for default model - provider, _ := engine.GetProvider(selectedCLI.Provider) - if provider 
!= nil { - cfg.Model = provider.DefaultModel - cfg.LargeModel = provider.LargeModel - } + prompt := "You are a helpful assistant. Respond with exactly one word.\n\nSay 'OK' to confirm you are working." + response, err := provider.Execute(ctx, prompt, llm.Text) - // Save config - if err := llm.SaveLLMConfig(cfg); err != nil { - fmt.Printf("\n❌ Failed to save configuration: %v\n", err) + if err != nil { + ui.PrintError(fmt.Sprintf("Test failed: %v", err)) return } - fmt.Printf("\n✓ CLI engine configured: %s\n", selectedCLI.Name) - if cfg.Model != "" { - fmt.Printf(" Default model: %s\n", cfg.Model) - } - if cfg.LargeModel != "" { - fmt.Printf(" Large model: %s\n", cfg.LargeModel) - } - fmt.Println(" Configuration saved to .sym/.env") + ui.PrintOK("Test successful!") + fmt.Printf(" Response: %s\n", strings.TrimSpace(response)) } -func configureEngineMode(cfg *llm.LLMConfig) { - fmt.Println("\n⚙️ Engine Mode Configuration") +func runLLMSetup(_ *cobra.Command, _ []string) { + ui.PrintTitle("LLM", "Provider Setup Instructions") fmt.Println() - items := []string{ - "auto - Automatically select best available engine", - "mcp - Always use MCP sampling (when available)", - "cli - Always use CLI tool", - "api - Always use OpenAI API", + // Show available providers + fmt.Println("Available providers:") + providers := llm.ListProviders() + for _, p := range providers { + status := "not installed" + if p.Available { + status = "ready" + } + fmt.Printf(" %s (%s): %s\n", p.Name, p.DisplayName, status) } + fmt.Println() - templates := &promptui.SelectTemplates{ - Label: "{{ . }}?", - Active: "▸ {{ . | cyan }}", - Inactive: " {{ . }}", - Selected: "✓ {{ . 
| green }}", - } + fmt.Println("Configuration files:") + fmt.Println(" .sym/config.json - Provider and model settings (safe to commit)") + fmt.Println(" .sym/.env - API keys only (gitignored)") + fmt.Println() - selectPrompt := promptui.Select{ - Label: "Select engine mode", - Items: items, - Templates: templates, - Size: 4, - } + fmt.Println("Example .sym/config.json:") + fmt.Println(` { + "llm": { + "provider": "claudecode", + "model": "sonnet" + } + }`) + fmt.Println() - index, _, err := selectPrompt.Run() - if err != nil { - fmt.Println("\nEngine mode configuration cancelled") - return + // Dynamically generate model aliases from registry + fmt.Println("Supported model aliases:") + for _, p := range providers { + if len(p.Models) > 0 { + modelIDs := make([]string, 0, len(p.Models)) + for _, m := range p.Models { + modelIDs = append(modelIDs, m.ID) + } + fmt.Printf(" %s: %s\n", p.DisplayName, strings.Join(modelIDs, ", ")) + } } + fmt.Println() - modes := []engine.Mode{ - engine.ModeAuto, - engine.ModeMCP, - engine.ModeCLI, - engine.ModeAPI, + // Show API key instructions for providers that require them + for _, p := range providers { + if p.APIKey.Required && p.APIKey.EnvVarName != "" { + fmt.Printf("For %s, also add to .sym/.env:\n", p.DisplayName) + fmt.Printf(" %s=%s...\n", p.APIKey.EnvVarName, p.APIKey.Prefix) + fmt.Println() + } } - cfg.Backend = modes[index] + fmt.Println("After configuration, run 'sym llm test' to verify.") +} - // Save config - if err := llm.SaveLLMConfig(cfg); err != nil { - fmt.Printf("\n❌ Failed to save configuration: %v\n", err) - return - } +// promptLLMBackendSetup is called from init command to setup LLM provider. 
+func promptLLMBackendSetup() { + // Use custom template to hide "type to filter" and typed characters + restore := useSelectTemplateNoFilter() + defer restore() - fmt.Printf("\n✓ Engine mode set to: %s\n", cfg.Backend) -} + fmt.Println() + ui.PrintTitle("LLM", "Configure LLM Provider") + fmt.Println(ui.Indent("Symphony uses LLM for policy conversion and code validation")) + fmt.Println() -func resetLLMConfig() { - fmt.Println("\n🔄 Resetting LLM Configuration") + // Get provider options dynamically from registry + providerOptions := llm.GetProviderOptions(true) // includes "Skip" - // Confirm - prompt := promptui.Prompt{ - Label: "Are you sure you want to reset LLM configuration", - IsConfirm: true, + // Select provider + var selectedDisplayName string + providerPrompt := &survey.Select{ + Message: "Select LLM provider:", + Options: providerOptions, } - result, err := prompt.Run() - if err != nil || strings.ToLower(result) != "y" { - fmt.Println("\nReset cancelled") + if err := survey.AskOne(providerPrompt, &selectedDisplayName); err != nil { + fmt.Println("Skipped LLM configuration") return } - // Save default config - cfg := llm.DefaultLLMConfig() - if err := llm.SaveLLMConfig(cfg); err != nil { - fmt.Printf("\n❌ Failed to reset configuration: %v\n", err) + if selectedDisplayName == "Skip" { + fmt.Println("Skipped LLM configuration") + fmt.Println(ui.Indent("Tip: Run 'sym init --setup-llm' to configure later")) return } - fmt.Println("\n✓ LLM configuration reset to defaults") -} - -func runLLMStatus(_ *cobra.Command, _ []string) { - fmt.Println("🤖 LLM Engine Status") - fmt.Println() - - // Load config - cfg := llm.LoadLLMConfig() - - // Create client to check engines - client := llm.NewClient(llm.WithConfig(cfg), llm.WithVerbose(false)) - - fmt.Println("Configuration:") - fmt.Printf(" Engine mode: %s\n", cfg.Backend) - if cfg.CLI != "" { - fmt.Printf(" CLI provider: %s\n", cfg.CLI) - } - if cfg.Model != "" { - fmt.Printf(" Model: %s\n", cfg.Model) + // Get 
provider info from registry + providerInfo := llm.GetProviderByDisplayName(selectedDisplayName) + if providerInfo == nil { + ui.PrintError(fmt.Sprintf("Unknown provider: %s", selectedDisplayName)) + return } - fmt.Println() - // Show engine availability - fmt.Println("Engine availability:") + providerName := providerInfo.Name + var modelID string - engines := client.GetEngines() - if len(engines) == 0 { - fmt.Println(" ⚠️ No engines configured") - } else { - for _, e := range engines { - status := "✗ unavailable" - if e.IsAvailable() { - status = "✓ available" - } - fmt.Printf(" %s: %s\n", e.Name(), status) + // Handle API key if required + if llm.RequiresAPIKey(providerName) { + if err := promptAndSaveAPIKey(providerName); err != nil { + ui.PrintError(fmt.Sprintf("Failed to save API key: %v", err)) + return } } - fmt.Println() - - // Show active engine - active := client.GetActiveEngine() - if active != nil { - fmt.Printf("Active engine: %s\n", active.Name()) - - caps := active.Capabilities() - fmt.Println("Capabilities:") - fmt.Printf(" Temperature: %v\n", caps.SupportsTemperature) - fmt.Printf(" Max tokens: %v\n", caps.SupportsMaxTokens) - fmt.Printf(" Complexity hint: %v\n", caps.SupportsComplexity) + // Select model (common for all providers) + modelOptions := llm.GetModelOptions(providerName) + if len(modelOptions) > 0 { + var selectedOption string + modelPrompt := &survey.Select{ + Message: fmt.Sprintf("Select %s model:", providerInfo.DisplayName), + Options: modelOptions, + Default: llm.GetDefaultModelOption(providerName), + } + if err := survey.AskOne(modelPrompt, &selectedOption); err != nil { + fmt.Println("Skipped model selection, using default") + modelID = providerInfo.DefaultModel + } else { + modelID = llm.GetModelIDFromOption(providerName, selectedOption) + } } else { - fmt.Println("⚠️ No active engine available") + modelID = providerInfo.DefaultModel } - fmt.Println() - fmt.Println("💡 Run 'sym llm setup' to configure engines") - fmt.Println("💡 Run 
'sym llm test' to verify connection") -} - -func runLLMTest(_ *cobra.Command, _ []string) { - fmt.Println("🧪 Testing LLM Engine Connection") - fmt.Println() - - // Load config - cfg := llm.LoadLLMConfig() - - // Create client - client := llm.NewClient(llm.WithConfig(cfg), llm.WithVerbose(true)) - - active := client.GetActiveEngine() - if active == nil { - fmt.Println("❌ No LLM engine available") - fmt.Println() - fmt.Println("Please configure an engine:") - fmt.Println(" sym llm setup") + // Save to config.json + if err := config.UpdateProjectConfigLLM(providerName, modelID); err != nil { + ui.PrintError(fmt.Sprintf("Failed to save config: %v", err)) return } - fmt.Printf("Testing engine: %s\n\n", active.Name()) - - // Create test request - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - response, err := client.Request( - "You are a helpful assistant. Respond with exactly one word.", - "Say 'OK' to confirm you are working.", - ).Execute(ctx) - - if err != nil { - fmt.Printf("\n❌ Test failed: %v\n", err) - os.Exit(1) - } - - fmt.Printf("\n✓ Test successful!\n") - fmt.Printf(" Response: %s\n", strings.TrimSpace(response)) + ui.PrintOK(fmt.Sprintf("LLM provider saved: %s (%s)", selectedDisplayName, modelID)) } -// promptLLMBackendSetup is called from init command to setup LLM engine. 
-func promptLLMBackendSetup() { - fmt.Println("\n🤖 LLM Engine Configuration") - fmt.Println(" Symphony uses LLM for policy conversion and code validation.") - fmt.Println() - - // Detect available CLIs - clis := engine.DetectAvailableCLIs() - - // Check API key - cfg := llm.LoadLLMConfig() - hasAPIKey := cfg.HasAPIKey() - - // Show detected tools - fmt.Println(" Detected LLM tools:") - hasAnyCLI := false - for _, cli := range clis { - status := "✗" - if cli.Available { - status = "✓" - hasAnyCLI = true - } - version := "" - if cli.Version != "" { - version = fmt.Sprintf(" (%s)", cli.Version) - } - fmt.Printf(" %s %s%s\n", status, cli.Name, version) +// promptAndSaveAPIKey prompts for API key and saves to .env +func promptAndSaveAPIKey(providerName string) error { + envVarName := llm.GetAPIKeyEnvVar(providerName) + if envVarName == "" { + return fmt.Errorf("provider %s does not have API key configuration", providerName) } - if hasAPIKey { - fmt.Println(" ✓ OpenAI API key (configured)") - } else { - fmt.Println(" ✗ OpenAI API key (not set)") + var apiKey string + prompt := &survey.Password{ + Message: fmt.Sprintf("Enter your %s:", envVarName), } - fmt.Println() - // If nothing available, skip - if !hasAnyCLI && !hasAPIKey { - fmt.Println(" ⚠️ No LLM engine available") - fmt.Println(" You can configure one later with: sym llm setup") - return + if err := survey.AskOne(prompt, &apiKey); err != nil { + return err } - // Build selection items - var items []string - var modes []engine.Mode - - items = append(items, "Auto (recommended) - Use best available engine") - modes = append(modes, engine.ModeAuto) - - for _, cli := range clis { - if cli.Available { - items = append(items, fmt.Sprintf("%s CLI", cli.Name)) - modes = append(modes, engine.ModeCLI) + // Validate API key using registry + if err := llm.ValidateAPIKey(providerName, apiKey); err != nil { + ui.PrintWarn(err.Error()) + // Continue anyway - it's a warning, not a blocking error + // But if the key is empty, we 
should return the error + if apiKey == "" { + return err } } - if hasAPIKey { - items = append(items, "OpenAI API") - modes = append(modes, engine.ModeAPI) + // Save to .env file + envPath := config.GetProjectEnvPath() + if err := saveAPIKeyToEnv(envPath, envVarName, apiKey); err != nil { + return err } - items = append(items, "Skip (configure later)") - modes = append(modes, "") + ui.PrintOK("API key saved to .sym/.env (gitignored)") - templates := &promptui.SelectTemplates{ - Label: "{{ . }}?", - Active: "▸ {{ . | cyan }}", - Inactive: " {{ . }}", - Selected: "✓ {{ . | green }}", + // Ensure .env is in .gitignore + if err := ensureGitignore(".sym/.env"); err != nil { + ui.PrintWarn(fmt.Sprintf("Failed to update .gitignore: %v", err)) } - selectPrompt := promptui.Select{ - Label: "Select your preferred LLM engine", - Items: items, - Templates: templates, - Size: len(items), - Stdout: &bellSkipper{}, - } + return nil +} - index, _, err := selectPrompt.Run() - if err != nil || modes[index] == "" { - fmt.Println("\n LLM engine configuration skipped") - fmt.Println(" Run 'sym llm setup' to configure later") - return - } +// saveAPIKeyToEnv saves the API key to the .env file +func saveAPIKeyToEnv(envPath, envVarName, apiKey string) error { + return envutil.SaveKeyToEnvFile(envPath, envVarName, apiKey) +} - // Update config - cfg.Backend = modes[index] - - // If CLI selected, set the specific CLI provider - if modes[index] == engine.ModeCLI { - // Find which CLI was selected - cliIndex := index - 1 // Account for "Auto" option - cliCount := 0 - for _, cli := range clis { - if cli.Available { - if cliCount == cliIndex { - cfg.CLI = string(cli.Provider) - provider, _ := engine.GetProvider(cli.Provider) - if provider != nil { - cfg.Model = provider.DefaultModel - cfg.LargeModel = provider.LargeModel - } - break - } - cliCount++ +// ensureGitignore ensures that the given path is in .gitignore +func ensureGitignore(path string) error { + gitignorePath := ".gitignore" + + // 
Read existing .gitignore + var lines []string + existingFile, err := os.Open(gitignorePath) + if err == nil { + scanner := bufio.NewScanner(existingFile) + for scanner.Scan() { + line := scanner.Text() + lines = append(lines, line) + // Check if already exists + if strings.TrimSpace(line) == path { + _ = existingFile.Close() + return nil // Already in .gitignore } } + _ = existingFile.Close() } - // Save config - if err := llm.SaveLLMConfig(cfg); err != nil { - fmt.Printf("\n ⚠️ Failed to save LLM configuration: %v\n", err) - return - } + // Add to .gitignore + lines = append(lines, "", "# Symphony API key configuration", path) + content := strings.Join(lines, "\n") + "\n" - fmt.Printf("\n ✓ LLM engine set to: %s\n", cfg.Backend) - if cfg.CLI != "" { - fmt.Printf(" CLI: %s\n", cfg.CLI) + if err := os.WriteFile(gitignorePath, []byte(content), 0644); err != nil { + return fmt.Errorf("failed to update .gitignore: %w", err) } - fmt.Println(" Configuration saved to .sym/.env") + + return nil } + diff --git a/internal/cmd/mcp.go b/internal/cmd/mcp.go index 3b726bf..2c348bf 100644 --- a/internal/cmd/mcp.go +++ b/internal/cmd/mcp.go @@ -8,6 +8,7 @@ import ( "github.com/DevSymphony/sym-cli/internal/git" "github.com/DevSymphony/sym-cli/internal/mcp" + "github.com/DevSymphony/sym-cli/internal/ui" "github.com/pkg/browser" "github.com/spf13/cobra" ) @@ -60,15 +61,16 @@ func runMCP(cmd *cobra.Command, args []string) error { // If no user-policy.json → Launch dashboard if !userPolicyExists { - fmt.Println("❌ User policy not found at:", userPolicyPath) - fmt.Println("📝 Opening dashboard to create policy...") + ui.PrintError(fmt.Sprintf("User policy not found at: %s", userPolicyPath)) + fmt.Println("Opening dashboard to create policy...") // Launch dashboard if err := launchDashboard(); err != nil { return fmt.Errorf("failed to launch dashboard: %w", err) } - fmt.Println("\n✓ Dashboard launched at http://localhost:8787") + fmt.Println() + ui.PrintOK("Dashboard launched at 
http://localhost:8787") fmt.Println("Please create your policy in the dashboard, then restart MCP server.") return nil } diff --git a/internal/cmd/mcp_register.go b/internal/cmd/mcp_register.go index a4da0b7..1baa5a4 100644 --- a/internal/cmd/mcp_register.go +++ b/internal/cmd/mcp_register.go @@ -9,7 +9,8 @@ import ( "runtime" "strings" - "github.com/manifoldco/promptui" + "github.com/AlecAivazis/survey/v2" + "github.com/DevSymphony/sym-cli/internal/ui" ) // MCPRegistrationConfig represents the MCP configuration structure @@ -41,159 +42,88 @@ type VSCodeServerConfig struct { Env map[string]string `json:"env,omitempty"` } -// editorItem represents an editor option with selection state -type editorItem struct { - Name string - AppID string - Selected bool - IsSubmit bool - IsSkip bool +// MCP tool options for multi-select +var mcpToolOptions = []string{ + "Claude Code", + "Cursor", + "VS Code Copilot", } -// bellSkipper wraps os.Stdout to skip bell characters (prevents alert sound) -type bellSkipper struct{} - -func (bs *bellSkipper) Write(b []byte) (int, error) { - const bell = '\a' - // Filter out bell characters - filtered := make([]byte, 0, len(b)) - for _, c := range b { - if c != bell { - filtered = append(filtered, c) - } - } - return os.Stdout.Write(filtered) -} - -func (bs *bellSkipper) Close() error { - return nil +// mcpToolToApp maps display name to internal app identifier +var mcpToolToApp = map[string]string{ + "Claude Code": "claude-code", + "Cursor": "cursor", + "VS Code Copilot": "vscode", } // promptMCPRegistration prompts user to register Symphony as MCP server func promptMCPRegistration() { // Check if npx is available if !checkNpxAvailable() { - fmt.Println("\n⚠ Warning: 'npx' not found. MCP features require Node.js.") - fmt.Println(" Download: https://nodejs.org/") + ui.PrintWarn("'npx' not found. 
MCP features require Node.js.") + fmt.Println(ui.Indent("Download: https://nodejs.org/")) - confirmPrompt := promptui.Prompt{ - Label: "Continue anyway", - IsConfirm: true, + var continueAnyway bool + prompt := &survey.Confirm{ + Message: "Continue anyway?", + Default: false, } - - result, err := confirmPrompt.Run() - if err != nil || strings.ToLower(result) != "y" { + if err := survey.AskOne(prompt, &continueAnyway); err != nil || !continueAnyway { fmt.Println("Skipped MCP registration") return } } - fmt.Println("\n📡 Would you like to register Symphony as an MCP server?") - fmt.Println(" (Symphony MCP provides code convention tools for AI assistants)") - fmt.Println(" Press Enter to toggle selection, then select Submit to apply") + // Use custom template to hide "type to filter" and typed characters + restore := useMultiSelectTemplateNoFilter() + defer restore() + + fmt.Println() + ui.PrintTitle("MCP", "Register Symphony as an MCP server") + fmt.Println(ui.Indent("Symphony MCP provides code convention tools for AI assistants")) + fmt.Println(ui.Indent("(Use arrows to move, space to select, enter to submit)")) fmt.Println() - // Initialize editor items - items := []editorItem{ - {Name: "Claude Desktop (global)", AppID: "claude-desktop"}, - {Name: "Claude Code (project)", AppID: "claude-code"}, - {Name: "Cursor (project)", AppID: "cursor"}, - {Name: "VS Code Copilot (project)", AppID: "vscode"}, - {Name: "Cline (global)", AppID: "cline"}, - {Name: "Submit", IsSubmit: true}, + // Multi-select prompt for tools + var selectedTools []string + prompt := &survey.MultiSelect{ + Message: "Select vibe coding tools to integrate:", + Options: mcpToolOptions, } - // Track cursor position across loop iterations - cursorPos := 0 - - // Multi-select loop - for { - // Count selected items first - selectedCount := 0 - for _, item := range items { - if item.Selected { - selectedCount++ - } - } - - // Build display items with checkboxes - displayItems := make([]string, len(items)) - 
for i, item := range items { - if item.IsSubmit { - if selectedCount == 0 { - displayItems[i] = "⏭ Skip" - } else { - displayItems[i] = fmt.Sprintf("✅ Submit (%d selected)", selectedCount) - } - } else { - checkbox := "☐" - if item.Selected { - checkbox = "☑" - } - displayItems[i] = fmt.Sprintf("%s %s", checkbox, item.Name) - } - } - - templates := &promptui.SelectTemplates{ - Label: "{{ . }}", - Active: "▸ {{ . | cyan }}", - Inactive: " {{ . }}", - Selected: "{{ . }}", - } - - prompt := promptui.Select{ - Label: "Select editors (Enter to toggle)", - Items: displayItems, - Templates: templates, - Size: 6, - HideSelected: true, - CursorPos: cursorPos, - Stdout: &bellSkipper{}, - } - - index, _, err := prompt.Run() - if err != nil { - fmt.Println("\nSkipped MCP registration") - return - } - - selectedItem := &items[index] + if err := survey.AskOne(prompt, &selectedTools); err != nil { + fmt.Println("Skipped MCP registration") + return + } - if selectedItem.IsSubmit { - // Collect selected apps - var selectedApps []string - for _, item := range items { - if item.Selected && item.AppID != "" { - selectedApps = append(selectedApps, item.AppID) - } - } + // If no tools selected, skip + if len(selectedTools) == 0 { + fmt.Println("Skipped MCP registration") + fmt.Println(ui.Indent("Tip: Run 'sym init --register-mcp' to register MCP later")) + return + } - // If no editors selected, act as Skip - if len(selectedApps) == 0 { - fmt.Println("Skipped MCP registration") - fmt.Println("\n💡 Tip: Run 'sym init --register-mcp' to register MCP later") - return - } + // Register selected tools + var registered []string + var failed []string - // Register all selected apps - successCount := 0 - for _, appID := range selectedApps { - if registerMCP(appID) == nil { - successCount++ - } - } - - if successCount > 0 { - fmt.Printf("\n✅ MCP registration complete! 
Registered to %d app(s).\n", successCount) - fmt.Println(" Restart/reload the apps to use Symphony.") - } - return + for _, tool := range selectedTools { + app := mcpToolToApp[tool] + if err := registerMCP(app); err != nil { + failed = append(failed, fmt.Sprintf("%s: %v", tool, err)) + } else { + registered = append(registered, tool) } + } - // Toggle selection for editor items - selectedItem.Selected = !selectedItem.Selected - // Preserve cursor position for next iteration - cursorPos = index + // Print results + fmt.Println() + if len(registered) > 0 { + ui.PrintOK(fmt.Sprintf("Registered: %s", strings.Join(registered, ", "))) + fmt.Println(ui.Indent("Reload/restart the tools to use Symphony")) + } + for _, f := range failed { + ui.PrintError(fmt.Sprintf("Failed to register %s", f)) } } @@ -202,19 +132,13 @@ func registerMCP(app string) error { configPath := getMCPConfigPath(app) if configPath == "" { - fmt.Printf("\n⚠ %s config path could not be determined\n", getAppDisplayName(app)) + fmt.Println(ui.Warn(fmt.Sprintf("%s config path could not be determined", getAppDisplayName(app)))) return fmt.Errorf("config path not determined") } - // Check if this is a project-specific config - isProjectConfig := app != "claude-desktop" && app != "cline" - - if isProjectConfig { - fmt.Printf("\n✓ Configuring %s (project-specific)\n", getAppDisplayName(app)) - } else { - fmt.Printf("\n✓ Configuring %s (global)\n", getAppDisplayName(app)) - } - fmt.Printf(" Location: %s\n", configPath) + // All supported apps are now project-specific + fmt.Println(ui.Indent(fmt.Sprintf("Configuring %s", getAppDisplayName(app)))) + fmt.Println(ui.Indent(fmt.Sprintf("Location: %s", configPath))) // Create config directory if it doesn't exist configDir := filepath.Dir(configPath) @@ -237,22 +161,22 @@ func registerMCP(app string) error { // Invalid JSON, create backup backupPath := configPath + ".bak" if err := os.WriteFile(backupPath, existingData, 0644); err != nil { - fmt.Printf(" ⚠ Failed to 
create backup: %v\n", err) + fmt.Println(ui.Indent(fmt.Sprintf("Failed to create backup: %v", err))) } else { - fmt.Printf(" ⚠ Invalid JSON, backup created: %s\n", filepath.Base(backupPath)) + fmt.Println(ui.Indent(fmt.Sprintf("Invalid JSON, backup created: %s", filepath.Base(backupPath)))) } vscodeConfig = VSCodeMCPConfig{} } else { // Valid JSON, create backup backupPath := configPath + ".bak" if err := os.WriteFile(backupPath, existingData, 0644); err != nil { - fmt.Printf(" ⚠ Failed to create backup: %v\n", err) + fmt.Println(ui.Indent(fmt.Sprintf("Failed to create backup: %v", err))) } else { - fmt.Printf(" Backup: %s\n", filepath.Base(backupPath)) + fmt.Println(ui.Indent(fmt.Sprintf("Backup: %s", filepath.Base(backupPath)))) } } } else { - fmt.Printf(" Creating new configuration file\n") + fmt.Println(ui.Indent("Creating new configuration file")) } // Initialize Servers if nil @@ -281,22 +205,22 @@ func registerMCP(app string) error { // Invalid JSON, create backup backupPath := configPath + ".bak" if err := os.WriteFile(backupPath, existingData, 0644); err != nil { - fmt.Printf(" ⚠ Failed to create backup: %v\n", err) + fmt.Println(ui.Indent(fmt.Sprintf("Failed to create backup: %v", err))) } else { - fmt.Printf(" ⚠ Invalid JSON, backup created: %s\n", filepath.Base(backupPath)) + fmt.Println(ui.Indent(fmt.Sprintf("Invalid JSON, backup created: %s", filepath.Base(backupPath)))) } config = MCPRegistrationConfig{} } else { // Valid JSON, create backup backupPath := configPath + ".bak" if err := os.WriteFile(backupPath, existingData, 0644); err != nil { - fmt.Printf(" ⚠ Failed to create backup: %v\n", err) + fmt.Println(ui.Indent(fmt.Sprintf("Failed to create backup: %v", err))) } else { - fmt.Printf(" Backup: %s\n", filepath.Base(backupPath)) + fmt.Println(ui.Indent(fmt.Sprintf("Backup: %s", filepath.Base(backupPath)))) } } } else { - fmt.Printf(" Creating new configuration file\n") + fmt.Println(ui.Indent("Creating new configuration file")) } // Initialize 
MCPServers if nil @@ -329,14 +253,11 @@ func registerMCP(app string) error { return fmt.Errorf("failed to write config: %w", err) } - fmt.Printf(" ✓ Symphony MCP server registered\n") + fmt.Println(ui.Indent("Symphony MCP server registered")) - // Create instructions file for project-specific configs - // Note: Cline has global MCP config but project-specific .clinerules - if isProjectConfig || app == "cline" { - if err := createInstructionsFile(app); err != nil { - fmt.Printf(" ⚠ Failed to create instructions file: %v\n", err) - } + // Create instructions file for all supported apps + if err := createInstructionsFile(app); err != nil { + fmt.Println(ui.Indent(fmt.Sprintf("Failed to create instructions file: %v", err))) } return nil @@ -448,25 +369,25 @@ func createInstructionsFile(app string) error { if appendMode { // Check if Symphony instructions already exist if strings.Contains(string(existingContent), "# Symphony Code Conventions") { - fmt.Printf(" ✓ Instructions already exist in %s\n", instructionsPath) + fmt.Println(ui.Indent(fmt.Sprintf("Instructions already exist in %s", instructionsPath))) return nil } // Append to existing file content = string(existingContent) + "\n\n" + content - fmt.Printf(" ✓ Appended Symphony instructions to %s\n", instructionsPath) + fmt.Println(ui.Indent(fmt.Sprintf("Appended Symphony instructions to %s", instructionsPath))) } else { // Create backup backupPath := instructionsPath + ".bak" if err := os.WriteFile(backupPath, existingContent, 0644); err != nil { - fmt.Printf(" ⚠ Failed to create backup: %v\n", err) + fmt.Println(ui.Indent(fmt.Sprintf("Failed to create backup: %v", err))) } else { - fmt.Printf(" Backup: %s\n", filepath.Base(backupPath)) + fmt.Println(ui.Indent(fmt.Sprintf("Backup: %s", filepath.Base(backupPath)))) } - fmt.Printf(" ✓ Created %s\n", instructionsPath) + fmt.Println(ui.Indent(fmt.Sprintf("Created %s", instructionsPath))) } } else { // Create new file - fmt.Printf(" ✓ Created %s\n", instructionsPath) + 
fmt.Println(ui.Indent(fmt.Sprintf("Created %s", instructionsPath))) } // Create directory if needed @@ -486,9 +407,9 @@ func createInstructionsFile(app string) error { if app == "vscode" { gitignorePath := ".github/instructions/" if err := ensureGitignore(gitignorePath); err != nil { - fmt.Printf(" ⚠ Warning: Failed to update .gitignore: %v\n", err) + fmt.Println(ui.Indent(fmt.Sprintf("Warning: Failed to update .gitignore: %v", err))) } else { - fmt.Printf(" ✓ Added %s to .gitignore\n", gitignorePath) + fmt.Println(ui.Indent(fmt.Sprintf("Added %s to .gitignore", gitignorePath))) } } diff --git a/internal/cmd/my_role.go b/internal/cmd/my_role.go index a119837..8de2b2e 100644 --- a/internal/cmd/my_role.go +++ b/internal/cmd/my_role.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/DevSymphony/sym-cli/internal/roles" + "github.com/DevSymphony/sym-cli/internal/ui" "github.com/spf13/cobra" ) @@ -41,8 +42,8 @@ func runMyRole(cmd *cobra.Command, args []string) { output := map[string]string{"error": "roles.json not found"} _ = json.NewEncoder(os.Stdout).Encode(output) } else { - fmt.Println("❌ roles.json not found") - fmt.Println("Run 'sym init' first") + ui.PrintError("roles.json not found") + fmt.Println(ui.Indent("Run 'sym init' first")) } os.Exit(1) } @@ -67,13 +68,11 @@ func runMyRole(cmd *cobra.Command, args []string) { _ = json.NewEncoder(os.Stdout).Encode(output) } else { if role == "" { - fmt.Println("⚠ No role selected") - fmt.Println("Run 'sym my-role --select' to select a role") - fmt.Println("Or use the dashboard: 'sym dashboard'") + ui.PrintWarn("No role selected") + fmt.Println(ui.Indent("Run 'sym my-role --select' to select a role")) } else { fmt.Printf("Current role: %s\n", role) - fmt.Println("\nTo change your role:") - fmt.Println(" sym my-role --select") + fmt.Println(ui.Indent("Run 'sym my-role --select' to change")) } } } @@ -81,18 +80,18 @@ func runMyRole(cmd *cobra.Command, args []string) { func selectNewRole() { availableRoles, err := 
roles.GetAvailableRoles() if err != nil { - fmt.Printf("❌ Failed to get available roles: %v\n", err) + ui.PrintError(fmt.Sprintf("Failed to get available roles: %v", err)) os.Exit(1) } if len(availableRoles) == 0 { - fmt.Println("❌ No roles defined in roles.json") + ui.PrintError("No roles defined in roles.json") os.Exit(1) } currentRole, _ := roles.GetCurrentRole() - fmt.Println("🎭 Select your role:") + ui.PrintTitle("Role", "Select your role") fmt.Println() for i, role := range availableRoles { marker := " " @@ -109,23 +108,23 @@ func selectNewRole() { input = strings.TrimSpace(input) if input == "" { - fmt.Println("⚠ No selection made") + ui.PrintWarn("No selection made") return } num, err := strconv.Atoi(input) if err != nil || num < 1 || num > len(availableRoles) { - fmt.Println("❌ Invalid selection") + ui.PrintError("Invalid selection") os.Exit(1) } selectedRole := availableRoles[num-1] if err := roles.SetCurrentRole(selectedRole); err != nil { - fmt.Printf("❌ Failed to save role: %v\n", err) + ui.PrintError(fmt.Sprintf("Failed to save role: %v", err)) os.Exit(1) } - fmt.Printf("✓ Your role has been changed to: %s\n", selectedRole) + ui.PrintOK(fmt.Sprintf("Your role has been changed to: %s", selectedRole)) } func handleError(msg string, err error, jsonMode bool) { @@ -133,6 +132,6 @@ func handleError(msg string, err error, jsonMode bool) { output := map[string]string{"error": fmt.Sprintf("%s: %v", msg, err)} _ = json.NewEncoder(os.Stdout).Encode(output) } else { - fmt.Printf("❌ %s: %v\n", msg, err) + ui.PrintError(fmt.Sprintf("%s: %v", msg, err)) } } diff --git a/internal/cmd/policy.go b/internal/cmd/policy.go index e497f94..1e7b36d 100644 --- a/internal/cmd/policy.go +++ b/internal/cmd/policy.go @@ -3,8 +3,10 @@ package cmd import ( "fmt" "os" + "github.com/DevSymphony/sym-cli/internal/config" "github.com/DevSymphony/sym-cli/internal/policy" + "github.com/DevSymphony/sym-cli/internal/ui" "github.com/spf13/cobra" ) @@ -48,7 +50,7 @@ func init() { func 
runPolicyPath(cmd *cobra.Command, args []string) { cfg, err := config.LoadConfig() if err != nil { - fmt.Printf("❌ Failed to load config: %v\n", err) + ui.PrintError(fmt.Sprintf("Failed to load config: %v", err)) os.Exit(1) } @@ -56,11 +58,11 @@ func runPolicyPath(cmd *cobra.Command, args []string) { // Set new path cfg.PolicyPath = policyPathSet if err := config.SaveConfig(cfg); err != nil { - fmt.Printf("❌ Failed to save config: %v\n", err) + ui.PrintError(fmt.Sprintf("Failed to save config: %v", err)) os.Exit(1) } - fmt.Printf("✓ Policy path updated: %s\n", policyPathSet) + ui.PrintOK(fmt.Sprintf("Policy path updated: %s", policyPathSet)) } else { // Show current path policyPath := cfg.PolicyPath @@ -81,9 +83,9 @@ func runPolicyPath(cmd *cobra.Command, args []string) { if err != nil { fmt.Printf("Error checking file: %v\n", err) } else if exists { - fmt.Println("✓ Policy file exists") + ui.PrintOK("Policy file exists") } else { - fmt.Println("⚠ Policy file does not exist") + ui.PrintWarn("Policy file does not exist") } } } @@ -91,7 +93,7 @@ func runPolicyPath(cmd *cobra.Command, args []string) { func runPolicyValidate(cmd *cobra.Command, args []string) { cfg, err := config.LoadConfig() if err != nil { - fmt.Printf("❌ Failed to load config: %v\n", err) + ui.PrintError(fmt.Sprintf("Failed to load config: %v", err)) os.Exit(1) } @@ -99,16 +101,16 @@ func runPolicyValidate(cmd *cobra.Command, args []string) { policyData, err := policy.LoadPolicy(cfg.PolicyPath) if err != nil { - fmt.Printf("❌ Failed to load policy: %v\n", err) + ui.PrintError(fmt.Sprintf("Failed to load policy: %v", err)) os.Exit(1) } if err := policy.ValidatePolicy(policyData); err != nil { - fmt.Printf("❌ Validation failed: %v\n", err) + ui.PrintError(fmt.Sprintf("Validation failed: %v", err)) os.Exit(1) } - fmt.Println("✓ Policy file is valid") + ui.PrintOK("Policy file is valid") fmt.Printf(" Version: %s\n", policyData.Version) fmt.Printf(" Rules: %d\n", len(policyData.Rules)) diff --git 
a/internal/cmd/survey_templates.go b/internal/cmd/survey_templates.go new file mode 100644 index 0000000..611c539 --- /dev/null +++ b/internal/cmd/survey_templates.go @@ -0,0 +1,70 @@ +package cmd + +import "github.com/AlecAivazis/survey/v2" + +// Custom Select template that: +// 1. Removes "type to filter" hint +// 2. Hides typed characters (removes .FilterMessage) +// 3. Shows clear control instructions +var selectTemplateNoFilter = ` +{{- define "option"}} + {{- if eq .SelectedIndex .CurrentIndex }}{{color .Config.Icons.SelectFocus.Format }}{{ .Config.Icons.SelectFocus.Text }} {{else}}{{color "default"}} {{end}} + {{- .CurrentOpt.Value}}{{ if ne ($.GetDescription .CurrentOpt) "" }} - {{color "cyan"}}{{ $.GetDescription .CurrentOpt }}{{end}} + {{- color "reset"}} +{{end}} +{{- if .ShowHelp }}{{- color .Config.Icons.Help.Format }}{{ .Config.Icons.Help.Text }} {{ .Help }}{{color "reset"}}{{"\n"}}{{end}} +{{- color .Config.Icons.Question.Format }}{{ .Config.Icons.Question.Text }} {{color "reset"}} +{{- color "default+hb"}}{{ .Message }}{{color "reset"}} +{{- if .ShowAnswer}}{{color "cyan"}} {{.Answer}}{{color "reset"}}{{"\n"}} +{{- else}} + {{- " "}}{{- color "cyan"}}[Arrow keys: move, Enter: select]{{color "reset"}} + {{- "\n"}} + {{- range $ix, $option := .PageEntries}} + {{- template "option" $.IterateOption $ix $option}} + {{- end}} +{{- end}}` + +// Custom MultiSelect template that: +// 1. Removes "type to filter" hint +// 2. Hides typed characters (removes .FilterMessage) +// 3. 
Shows clear control instructions +var multiSelectTemplateNoFilter = ` +{{- define "option"}} + {{- if eq .SelectedIndex .CurrentIndex }}{{color .Config.Icons.SelectFocus.Format }}{{ .Config.Icons.SelectFocus.Text }}{{color "reset"}}{{else}} {{end}} + {{- if index .Checked .CurrentOpt.Index }}{{color .Config.Icons.MarkedOption.Format }} {{ .Config.Icons.MarkedOption.Text }} {{else}}{{color .Config.Icons.UnmarkedOption.Format }} {{ .Config.Icons.UnmarkedOption.Text }} {{end}} + {{- color "reset"}} + {{- " "}}{{- .CurrentOpt.Value}}{{ if ne ($.GetDescription .CurrentOpt) "" }} - {{color "cyan"}}{{ $.GetDescription .CurrentOpt }}{{color "reset"}}{{end}} +{{end}} +{{- if .ShowHelp }}{{- color .Config.Icons.Help.Format }}{{ .Config.Icons.Help.Text }} {{ .Help }}{{color "reset"}}{{"\n"}}{{end}} +{{- color .Config.Icons.Question.Format }}{{ .Config.Icons.Question.Text }} {{color "reset"}} +{{- color "default+hb"}}{{ .Message }}{{color "reset"}} +{{- if .ShowAnswer}}{{color "cyan"}} {{.Answer}}{{color "reset"}}{{"\n"}} +{{- else }} + {{- " "}}{{- color "cyan"}}[Arrow keys: move, Space: toggle, Enter: confirm]{{color "reset"}} + {{- "\n"}} + {{- range $ix, $option := .PageEntries}} + {{- template "option" $.IterateOption $ix $option}} + {{- end}} +{{- end}}` + +// useSelectTemplateNoFilter temporarily overrides the global Select template +// to hide "type to filter" and prevent typed characters from showing. +// Returns a restore function that must be called to restore the original template. +func useSelectTemplateNoFilter() func() { + original := survey.SelectQuestionTemplate + survey.SelectQuestionTemplate = selectTemplateNoFilter + return func() { + survey.SelectQuestionTemplate = original + } +} + +// useMultiSelectTemplateNoFilter temporarily overrides the global MultiSelect template +// to hide "type to filter" and prevent typed characters from showing. +// Returns a restore function that must be called to restore the original template. 
+func useMultiSelectTemplateNoFilter() func() { + original := survey.MultiSelectQuestionTemplate + survey.MultiSelectQuestionTemplate = multiSelectTemplateNoFilter + return func() { + survey.MultiSelectQuestionTemplate = original + } +} diff --git a/internal/cmd/validate.go b/internal/cmd/validate.go index b805eea..0e5a09b 100644 --- a/internal/cmd/validate.go +++ b/internal/cmd/validate.go @@ -6,9 +6,9 @@ import ( "fmt" "os" "path/filepath" - "time" "github.com/DevSymphony/sym-cli/internal/llm" + "github.com/DevSymphony/sym-cli/internal/ui" "github.com/DevSymphony/sym-cli/internal/validator" "github.com/DevSymphony/sym-cli/pkg/schema" "github.com/spf13/cobra" @@ -73,17 +73,11 @@ func runValidate(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to parse policy: %w", err) } - // Create LLM client - llmClient := llm.NewClient( - llm.WithTimeout(time.Duration(validateTimeout) * time.Second), - ) - - // Ensure at least one backend is available (MCP/CLI/API) - availabilityCtx, cancel := context.WithTimeout(context.Background(), time.Duration(validateTimeout)*time.Second) - defer cancel() - - if err := llmClient.CheckAvailability(availabilityCtx); err != nil { - return fmt.Errorf("no available LLM backend for validate: %w\nTip: run 'sym init --setup-llm' or configure LLM_BACKEND / LLM_CLI / OPENAI_API_KEY in .sym/.env", err) + // Create LLM provider + cfg := llm.LoadConfig() + llmProvider, err := llm.New(cfg) + if err != nil { + return fmt.Errorf("no available LLM backend for validate: %w\nTip: configure provider in .sym/config.json", err) } var changes []validator.GitChange @@ -110,7 +104,7 @@ func runValidate(cmd *cobra.Command, args []string) error { // Create unified validator that handles all engines + RBAC v := validator.NewValidator(&policy, true) // verbose=true for CLI - v.SetLLMClient(llmClient) + v.SetLLMProvider(llmProvider) defer func() { if err := v.Close(); err != nil { fmt.Printf("Warning: failed to close validator: %v\n", err) @@ 
-141,7 +135,7 @@ func printValidationResult(result *validator.ValidationResult) { fmt.Printf("Failed: %d\n\n", result.Failed) if len(result.Violations) == 0 { - fmt.Println("✓ All checks passed!") + ui.PrintOK("All checks passed") return } diff --git a/internal/config/project.go b/internal/config/project.go new file mode 100644 index 0000000..8e576d4 --- /dev/null +++ b/internal/config/project.go @@ -0,0 +1,115 @@ +package config + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// ProjectConfig represents the .sym/config.json structure +type ProjectConfig struct { + LLM LLMConfig `json:"llm,omitempty"` + MCP MCPConfig `json:"mcp,omitempty"` + PolicyPath string `json:"policy_path,omitempty"` +} + +// LLMConfig holds LLM provider settings +type LLMConfig struct { + Provider string `json:"provider,omitempty"` // "claudecode", "geminicli", "openaiapi" + Model string `json:"model,omitempty"` // Model name +} + +// MCPConfig holds MCP tool registration settings +type MCPConfig struct { + Tools []string `json:"tools,omitempty"` // ["vscode", "claude-code", "cursor"] +} + +const ( + symDir = ".sym" + projectConfigFile = "config.json" + projectEnvFile = ".env" +) + +// GetProjectConfigPath returns the path to .sym/config.json +func GetProjectConfigPath() string { + return filepath.Join(symDir, projectConfigFile) +} + +// GetProjectEnvPath returns the path to .sym/.env +func GetProjectEnvPath() string { + return filepath.Join(symDir, projectEnvFile) +} + +// LoadProjectConfig loads the project configuration from .sym/config.json +func LoadProjectConfig() (*ProjectConfig, error) { + configPath := GetProjectConfigPath() + data, err := os.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + // Return empty config if file doesn't exist + return &ProjectConfig{ + PolicyPath: ".sym/user-policy.json", + }, nil + } + return nil, fmt.Errorf("failed to read config: %w", err) + } + + var cfg ProjectConfig + if err := json.Unmarshal(data, &cfg); err != 
nil { + return nil, fmt.Errorf("invalid config file: %w", err) + } + + return &cfg, nil +} + +// SaveProjectConfig saves the project configuration to .sym/config.json +func SaveProjectConfig(cfg *ProjectConfig) error { + // Ensure .sym directory exists + if err := os.MkdirAll(symDir, 0755); err != nil { + return fmt.Errorf("failed to create .sym directory: %w", err) + } + + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal config: %w", err) + } + + configPath := GetProjectConfigPath() + if err := os.WriteFile(configPath, data, 0644); err != nil { + return fmt.Errorf("failed to write config: %w", err) + } + + return nil +} + +// UpdateProjectConfigLLM updates only the LLM section of project config +func UpdateProjectConfigLLM(provider, model string) error { + cfg, err := LoadProjectConfig() + if err != nil { + cfg = &ProjectConfig{} + } + + cfg.LLM.Provider = provider + cfg.LLM.Model = model + + return SaveProjectConfig(cfg) +} + +// UpdateProjectConfigMCP updates only the MCP section of project config +func UpdateProjectConfigMCP(tools []string) error { + cfg, err := LoadProjectConfig() + if err != nil { + cfg = &ProjectConfig{} + } + + cfg.MCP.Tools = tools + + return SaveProjectConfig(cfg) +} + +// ProjectConfigExists checks if .sym/config.json exists +func ProjectConfigExists() bool { + _, err := os.Stat(GetProjectConfigPath()) + return err == nil +} diff --git a/internal/converter/converter.go b/internal/converter/converter.go index e32836e..1de8e55 100644 --- a/internal/converter/converter.go +++ b/internal/converter/converter.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "runtime" "strings" "sync" @@ -20,18 +21,18 @@ const llmValidatorEngine = "llm-validator" // Converter is the main converter with language-based routing type Converter struct { - llmClient *llm.Client - outputDir string + llmProvider llm.Provider + outputDir string } // NewConverter creates a new converter instance -func 
NewConverter(llmClient *llm.Client, outputDir string) *Converter { +func NewConverter(provider llm.Provider, outputDir string) *Converter { if outputDir == "" { outputDir = ".sym" } return &Converter{ - llmClient: llmClient, - outputDir: outputDir, + llmProvider: provider, + outputDir: outputDir, } } @@ -89,6 +90,7 @@ func (c *Converter) Convert(ctx context.Context, userPolicy *schema.UserPolicy) } // Step 4: Convert rules for each linter in parallel using goroutines + // Limit concurrency to CPU count to prevent CPU spike result := &ConvertResult{ GeneratedFiles: []string{}, CodePolicy: codePolicy, @@ -99,6 +101,13 @@ func (c *Converter) Convert(ctx context.Context, userPolicy *schema.UserPolicy) var wg sync.WaitGroup var mu sync.Mutex + // Track failed rules per linter for fallback to llm-validator + failedRulesPerLinter := make(map[string][]string) + + // Limit concurrent LLM calls to CPU count + maxConcurrent := runtime.NumCPU() + sem := make(chan struct{}, maxConcurrent) + for linterName, rules := range linterRules { if len(rules) == 0 { continue @@ -113,46 +122,105 @@ func (c *Converter) Convert(ctx context.Context, userPolicy *schema.UserPolicy) go func(linter string, ruleSet []schema.UserRule) { defer wg.Done() + // Acquire semaphore + sem <- struct{}{} + defer func() { <-sem }() + // Get linter converter converter := c.getLinterConverter(linter) if converter == nil { mu.Lock() result.Errors[linter] = fmt.Errorf("unsupported linter: %s", linter) + // Mark all rules as failed for this linter + for _, rule := range ruleSet { + failedRulesPerLinter[linter] = append(failedRulesPerLinter[linter], rule.ID) + } mu.Unlock() return } - // Convert rules using LLM - configFile, err := converter.ConvertRules(ctx, ruleSet, c.llmClient) + // Convert rules using LLM - now returns ConversionResult with per-rule tracking + convResult, err := converter.ConvertRules(ctx, ruleSet, c.llmProvider) if err != nil { mu.Lock() result.Errors[linter] = fmt.Errorf("conversion failed: 
%w", err) + // Total failure - all rules should fallback + for _, rule := range ruleSet { + failedRulesPerLinter[linter] = append(failedRulesPerLinter[linter], rule.ID) + } mu.Unlock() return } - // Write config file to .sym directory - outputPath := filepath.Join(c.outputDir, configFile.Filename) - if err := os.WriteFile(outputPath, configFile.Content, 0644); err != nil { + // Track failed rules from partial conversion + if len(convResult.FailedRules) > 0 { mu.Lock() - result.Errors[linter] = fmt.Errorf("failed to write file: %w", err) + failedRulesPerLinter[linter] = append(failedRulesPerLinter[linter], convResult.FailedRules...) mu.Unlock() - return } - mu.Lock() - result.GeneratedFiles = append(result.GeneratedFiles, outputPath) - mu.Unlock() + // Only write config if at least one rule succeeded + if convResult.Config != nil && len(convResult.SuccessRules) > 0 { + outputPath := filepath.Join(c.outputDir, convResult.Config.Filename) + if err := os.WriteFile(outputPath, convResult.Config.Content, 0644); err != nil { + mu.Lock() + result.Errors[linter] = fmt.Errorf("failed to write file: %w", err) + mu.Unlock() + return + } - fmt.Fprintf(os.Stderr, "✓ Generated %s configuration: %s\n", linter, outputPath) + mu.Lock() + result.GeneratedFiles = append(result.GeneratedFiles, outputPath) + mu.Unlock() + + fmt.Fprintf(os.Stderr, "✓ Generated %s configuration: %s\n", linter, outputPath) + } }(linterName, rules) } // Wait for all goroutines to complete wg.Wait() - // Check if we have any successful conversions - if len(result.GeneratedFiles) == 0 && len(result.Errors) > 0 { + // Step 4.1: Update ruleToLinters mapping - remove failed linters and add llm-validator fallback + for linter, failedRuleIDs := range failedRulesPerLinter { + for _, ruleID := range failedRuleIDs { + // Remove failed linter from this rule's linters + currentLinters := ruleToLinters[ruleID] + updatedLinters := []string{} + for _, l := range currentLinters { + if l != linter { + updatedLinters = 
append(updatedLinters, l) + } + } + + // Add llm-validator as fallback if not already present + hasLLMValidator := false + for _, l := range updatedLinters { + if l == llmValidatorEngine { + hasLLMValidator = true + break + } + } + if !hasLLMValidator { + updatedLinters = append(updatedLinters, llmValidatorEngine) + } + + ruleToLinters[ruleID] = updatedLinters + } + } + + // Log fallback info + totalFallbacks := 0 + for _, failedRuleIDs := range failedRulesPerLinter { + totalFallbacks += len(failedRuleIDs) + } + if totalFallbacks > 0 { + fmt.Fprintf(os.Stderr, "ℹ️ %d rule(s) fell back to llm-validator due to conversion failures\n", totalFallbacks) + } + + // Check if we have any successful conversions (excluding llm-validator rules) + // Note: We don't fail if all rules went to llm-validator + if len(result.GeneratedFiles) == 0 && len(result.Errors) > 0 && totalFallbacks == 0 { return result, fmt.Errorf("all conversions failed") } @@ -250,7 +318,7 @@ func (c *Converter) Convert(ctx context.Context, userPolicy *schema.UserPolicy) } // routeRulesWithLLM uses LLM to determine which linters are appropriate for each rule -// Rules are processed in parallel for better performance +// Rules are processed in parallel with concurrency limited to CPU count func (c *Converter) routeRulesWithLLM(ctx context.Context, userPolicy *schema.UserPolicy) map[string][]schema.UserRule { type routeResult struct { rule schema.UserRule @@ -260,7 +328,18 @@ func (c *Converter) routeRulesWithLLM(ctx context.Context, userPolicy *schema.Us results := make(chan routeResult, len(userPolicy.Rules)) var wg sync.WaitGroup - // Process rules in parallel + // Limit concurrent LLM calls to prevent resource exhaustion + // Use CPU/4, minimum 2, maximum 4 to balance performance and stability + maxConcurrent := runtime.NumCPU() / 4 + if maxConcurrent < 2 { + maxConcurrent = 2 + } + if maxConcurrent > 4 { + maxConcurrent = 4 + } + sem := make(chan struct{}, maxConcurrent) + + // Process rules in parallel 
with concurrency limit for _, rule := range userPolicy.Rules { // Get languages for this rule languages := rule.Languages @@ -272,7 +351,11 @@ func (c *Converter) routeRulesWithLLM(ctx context.Context, userPolicy *schema.Us availableLinters := c.getAvailableLinters(languages) if len(availableLinters) == 0 { // No language-specific linters, use llm-validator - results <- routeResult{rule: rule, linters: []string{llmValidatorEngine}} + select { + case results <- routeResult{rule: rule, linters: []string{llmValidatorEngine}}: + case <-ctx.Done(): + continue + } continue } @@ -280,14 +363,31 @@ func (c *Converter) routeRulesWithLLM(ctx context.Context, userPolicy *schema.Us go func(r schema.UserRule, linters []string) { defer wg.Done() + // Acquire semaphore with context check + select { + case sem <- struct{}{}: + case <-ctx.Done(): + return + } + defer func() { <-sem }() + // Ask LLM which linters are appropriate for this rule selectedLinters := c.selectLintersForRule(ctx, r, linters) + // Send result with context check to prevent deadlock if len(selectedLinters) == 0 { // LLM couldn't map to any linter, use llm-validator - results <- routeResult{rule: r, linters: []string{llmValidatorEngine}} + select { + case results <- routeResult{rule: r, linters: []string{llmValidatorEngine}}: + case <-ctx.Done(): + return + } } else { - results <- routeResult{rule: r, linters: selectedLinters} + select { + case results <- routeResult{rule: r, linters: selectedLinters}: + case <-ctx.Done(): + return + } } }(rule, availableLinters) } @@ -405,8 +505,9 @@ Reason: Requires knowing which packages are "large"`, linterDescriptions, routin userPrompt := fmt.Sprintf("Rule: %s\nCategory: %s", rule.Say, rule.Category) - // Call LLM with medium complexity (needs some thought for linter selection) - response, err := c.llmClient.Request(systemPrompt, userPrompt).WithComplexity(llm.ComplexityLow).Execute(ctx) + // Call LLM + prompt := systemPrompt + "\n\n" + userPrompt + response, err := 
c.llmProvider.Execute(ctx, prompt, llm.JSON) if err != nil { fmt.Fprintf(os.Stderr, "Warning: LLM routing failed for rule %s: %v\n", rule.ID, err) return []string{} // Will fall back to llm-validator diff --git a/internal/envutil/env.go b/internal/envutil/env.go index 78b1f5e..f82e0e1 100644 --- a/internal/envutil/env.go +++ b/internal/envutil/env.go @@ -19,11 +19,6 @@ func GetAPIKey(keyName string) string { return LoadKeyFromEnvFile(filepath.Join(".sym", ".env"), keyName) } -// GetPolicyPath retrieves policy path from .sym/.env -func GetPolicyPath() string { - return LoadKeyFromEnvFile(filepath.Join(".sym", ".env"), "POLICY_PATH") -} - // LoadKeyFromEnvFile reads a specific key from .env file func LoadKeyFromEnvFile(envPath, key string) string { file, err := os.Open(envPath) diff --git a/internal/llm/README.md b/internal/llm/README.md index 94bcc6f..c4bfdb8 100644 --- a/internal/llm/README.md +++ b/internal/llm/README.md @@ -1,8 +1,181 @@ -# llm +# LLM Package -OpenAI API 클라이언트를 제공합니다. +Unified interface for LLM providers. -LLM 기반 추론 및 검증을 지원하며 정책 변환 시 자연어 규칙 해석에 활용됩니다. +## File Structure -**사용자**: cmd, converter, engine, mcp, validator -**의존성**: 없음 +``` +internal/llm/ +├── llm.go # Provider, RawProvider interface, Config, ResponseFormat +├── registry.go # RegisterProvider, New, ListProviders +├── wrapper.go # parsedProvider (auto parsing wrapper) +├── config.go # LoadConfig +├── parser.go # Response parsing (private) +├── DESIGN.md # Architecture documentation +├── claudecode/ # Claude Code CLI provider +├── geminicli/ # Gemini CLI provider +└── openaiapi/ # OpenAI API provider +``` + +## Usage + +```go +import "github.com/DevSymphony/sym-cli/internal/llm" + +// 1. Load config +cfg := llm.LoadConfig() + +// 2. Create provider +provider, err := llm.New(cfg) +if err != nil { + return err // CLI not installed or API key missing +} + +// 3. 
Execute prompt +response, err := provider.Execute(ctx, prompt, llm.JSON) +``` + +### Configuration + +Config file: `.sym/config.json` + +```json +{ + "llm": { + "provider": "claudecode", + "model": "sonnet" + } +} +``` + +For OpenAI API, also add API key to `.sym/.env`: + +```bash +OPENAI_API_KEY=sk-... +``` + +### Response Format + +| Format | Description | +|--------|-------------| +| `llm.Text` | Return raw text | +| `llm.JSON` | Extract JSON from LLM response | +| `llm.XML` | Extract XML from LLM response | + +`llm.JSON` and `llm.XML` automatically extract structured data when LLM returns JSON/XML with preamble text. + +## Provider List + +| Name | Type | Default Model | Installation | +|------|------|---------------|--------------| +| `claudecode` | CLI | sonnet | `npm i -g @anthropic-ai/claude-code` | +| `geminicli` | CLI | gemini-2.5-flash | `npm i -g @google/gemini-cli` | +| `openaiapi` | API | gpt-4o-mini | Set `OPENAI_API_KEY` env var | + +### Check Provider Status + +```go +providers := llm.ListProviders() +for _, p := range providers { + fmt.Printf("%s: available=%v\n", p.Name, p.Available) +} +``` + +## Adding New Provider + +### Step 1: Create Directory + +``` +internal/llm/<provider-name>/ +└── provider.go +``` + +### Step 2: Implement RawProvider Interface + +```go +package myprovider + +import ( + "context" + "github.com/DevSymphony/sym-cli/internal/llm" +) + +type Provider struct { + model string +} + +// Compile-time check: Provider must implement RawProvider interface +var _ llm.RawProvider = (*Provider)(nil) + +func (p *Provider) Name() string { + return "myprovider" +} + +func (p *Provider) ExecuteRaw(ctx context.Context, prompt string, format llm.ResponseFormat) (string, error) { + // Execute prompt and return raw response + response := callLLM(prompt) + return response, nil // No manual parsing needed +} + +func (p *Provider) Close() error { + return nil +} +``` + +> **Note**: The `var _ llm.RawProvider = (*Provider)(nil)` line is a compile-time check that
ensures `Provider` implements the `RawProvider` interface. If any method is missing or has the wrong signature, the code will fail to compile. + +### Step 3: Register in init() + +```go +func init() { + llm.RegisterProvider("myprovider", newProvider, llm.ProviderInfo{ + Name: "myprovider", + DisplayName: "My Provider", + DefaultModel: "default-model-v1", + Available: checkAvailability(), // Check CLI exists or API key set + Path: cliPath, // CLI path (empty for API) + Models: []llm.ModelInfo{ + {ID: "model-v1", DisplayName: "Model V1", Description: "Standard model", Recommended: true}, + {ID: "model-v2", DisplayName: "Model V2", Description: "Advanced model", Recommended: false}, + }, + APIKey: llm.APIKeyConfig{ + Required: true, // Set false for CLI-based providers + EnvVarName: "MY_PROVIDER_API_KEY", // Environment variable name + Prefix: "mp-", // Expected prefix for validation (optional) + }, + }) +} + +func newProvider(cfg llm.Config) (llm.RawProvider, error) { + // Return error if CLI not installed or API key missing + if !isAvailable() { + return nil, fmt.Errorf("provider not available") + } + + model := cfg.Model + if model == "" { + model = "default-model-v1" + } + + return &Provider{model: model}, nil +} +``` + +### Step 4: Add Import to Bootstrap + +```go +// internal/bootstrap/providers.go +import ( + _ "github.com/DevSymphony/sym-cli/internal/llm/myprovider" +) +``` + +### Key Rules + +- Add compile-time check: `var _ llm.RawProvider = (*Provider)(nil)` +- Implement `RawProvider.ExecuteRaw()` - parsing is handled automatically by wrapper +- Return clear error message if CLI not installed or API key missing +- Check availability in init() to set ProviderInfo.Available +- Define `Models` with at least one model marked as `Recommended: true` +- Set `APIKey.Required: false` for CLI-based providers (they handle auth internally) +- Refer to existing providers (claudecode, openaiapi) for patterns diff --git a/internal/llm/claudecode/provider.go 
b/internal/llm/claudecode/provider.go new file mode 100644 index 0000000..4d17597 --- /dev/null +++ b/internal/llm/claudecode/provider.go @@ -0,0 +1,125 @@ +// Package claudecode provides the Claude Code CLI LLM provider. +package claudecode + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "strings" + "time" + + "github.com/DevSymphony/sym-cli/internal/llm" +) + +const ( + providerName = "claudecode" + displayName = "Claude Code CLI" + command = "claude" + defaultModel = "sonnet" // Claude CLI accepts short aliases: sonnet, opus, haiku + defaultTimeout = 120 * time.Second +) + +func init() { + // Check if CLI is available + path, _ := exec.LookPath(command) + available := path != "" + + llm.RegisterProvider(providerName, newProvider, llm.ProviderInfo{ + Name: providerName, + DisplayName: displayName, + DefaultModel: defaultModel, + Available: available, + Path: path, + Models: []llm.ModelInfo{ + {ID: "sonnet", DisplayName: "sonnet", Description: "Balanced performance and speed", Recommended: true}, + {ID: "opus", DisplayName: "opus", Description: "Highest capability", Recommended: false}, + {ID: "haiku", DisplayName: "haiku", Description: "Fast and efficient", Recommended: false}, + }, + APIKey: llm.APIKeyConfig{Required: false}, + }) +} + +// Provider implements llm.RawProvider for Claude Code CLI. +type Provider struct { + model string + timeout time.Duration + verbose bool + cliPath string +} + +// Compile-time check: Provider must implement RawProvider interface +var _ llm.RawProvider = (*Provider)(nil) + +// newProvider creates a new Claude Code provider. +// Returns error if Claude CLI is not installed. 
+func newProvider(cfg llm.Config) (llm.RawProvider, error) { + path, err := exec.LookPath(command) + if err != nil { + return nil, fmt.Errorf("claude CLI not installed: run 'npm install -g @anthropic-ai/claude-cli' to install") + } + + model := cfg.Model + if model == "" { + model = defaultModel + } + + return &Provider{ + model: model, + timeout: defaultTimeout, + verbose: cfg.Verbose, + cliPath: path, + }, nil +} + +func (p *Provider) Name() string { + return providerName +} + +func (p *Provider) ExecuteRaw(ctx context.Context, prompt string, format llm.ResponseFormat) (string, error) { + args := []string{"-p", prompt, "--output-format", "text"} + + // MCP 서버 로딩 비활성화: 재귀 호출 방지 + // Symphony가 claude CLI를 호출할 때 MCP가 다시 로드되면 무한 루프 발생 + // --mcp-config는 JSON 문자열도 지원함 (파일 경로 외에) + args = append(args, "--strict-mcp-config", "--mcp-config", `{"mcpServers":{}}`) + + if p.model != "" { + args = append(args, "--model", p.model) + } + + if p.verbose { + fmt.Fprintf(os.Stderr, "[claudecode] Model: %s, Prompt: %d chars\n", p.model, len(prompt)) + } + + cmdCtx, cancel := context.WithTimeout(ctx, p.timeout) + defer cancel() + + cmd := exec.CommandContext(cmdCtx, p.cliPath, args...) + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + if err != nil { + if cmdCtx.Err() == context.DeadlineExceeded { + return "", fmt.Errorf("claude CLI timed out after %v", p.timeout) + } + return "", fmt.Errorf("claude CLI failed: %w\nstderr: %s", err, stderr.String()) + } + + response := strings.TrimSpace(stdout.String()) + + if p.verbose { + fmt.Fprintf(os.Stderr, "[claudecode] Response: %d chars\n", len(response)) + } + + return response, nil +} + +// Close is a no-op for CLI-based providers. 
+func (p *Provider) Close() error { + return nil +} diff --git a/internal/llm/client.go b/internal/llm/client.go deleted file mode 100644 index 800f190..0000000 --- a/internal/llm/client.go +++ /dev/null @@ -1,327 +0,0 @@ -package llm - -import ( - "context" - "fmt" - "os" - "time" - - "github.com/DevSymphony/sym-cli/internal/envutil" - "github.com/DevSymphony/sym-cli/internal/llm/engine" - mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" -) - -const ( - defaultMaxTokens = 1000 - defaultTemperature = 1.0 - defaultTimeout = 60 * time.Second -) - -const ( - // ModeAPI uses OpenAI API. - ModeAPI = engine.ModeAPI - // ModeMCP uses MCP sampling. - ModeMCP = engine.ModeMCP - // ModeCLI uses CLI engine. - ModeCLI = engine.ModeCLI - // ModeAuto automatically selects the best available engine. - ModeAuto = engine.ModeAuto -) - -// Client represents an LLM client with fallback chain support. -type Client struct { - // Engine configuration - config *LLMConfig - mode engine.Mode - engines []engine.LLMEngine - mcpSession *mcpsdk.ServerSession - - // Default request parameters - maxTokens int - temperature float64 - verbose bool -} - -// ClientOption is a functional option for configuring the client. -type ClientOption func(*Client) - -// WithMaxTokens sets the default max tokens. -func WithMaxTokens(maxTokens int) ClientOption { - return func(c *Client) { c.maxTokens = maxTokens } -} - -// WithTemperature sets the default temperature. -func WithTemperature(temperature float64) ClientOption { - return func(c *Client) { c.temperature = temperature } -} - -// WithTimeout sets the HTTP client timeout (for API engine). -func WithTimeout(_ time.Duration) ClientOption { - // Note: This is handled by individual engines now - return func(_ *Client) {} -} - -// WithVerbose enables verbose logging. -func WithVerbose(verbose bool) ClientOption { - return func(c *Client) { c.verbose = verbose } -} - -// WithMCPSession sets the MCP session for MCP mode. 
-func WithMCPSession(session *mcpsdk.ServerSession) ClientOption { - return func(c *Client) { - c.mcpSession = session - c.mode = engine.ModeMCP - } -} - -// WithConfig sets a custom LLM configuration. -func WithConfig(cfg *LLMConfig) ClientOption { - return func(c *Client) { - if cfg == nil { - return - } - c.config = cfg - if mode := cfg.GetEffectiveBackend(); mode != "" { - c.mode = mode - } - } -} - -// WithMode sets the preferred engine mode. -func WithMode(mode engine.Mode) ClientOption { - return func(c *Client) { - c.mode = mode - } -} - -// NewClient creates a new LLM client. -func NewClient(opts ...ClientOption) *Client { - // Load default config - config := LoadLLMConfig() - - apiKey := envutil.GetAPIKey("OPENAI_API_KEY") - config.APIKey = apiKey - - client := &Client{ - config: config, - mode: config.GetEffectiveBackend(), - maxTokens: defaultMaxTokens, - temperature: defaultTemperature, - verbose: false, - } - - // Apply options - for _, opt := range opts { - opt(client) - } - - // Initialize engine chain - client.initEngines() - - return client -} - -// initEngines initializes the engine fallback chain based on configuration. -func (c *Client) initEngines() { - c.engines = []engine.LLMEngine{} - - // Determine which engines to include based on mode - switch c.mode { - case engine.ModeMCP: - c.addMCPEngine() - case engine.ModeCLI: - c.addCLIEngine() - case engine.ModeAPI: - c.addAPIEngine() - case engine.ModeAuto: - fallthrough - default: - // add all available engines - c.addMCPEngine() - c.addCLIEngine() - c.addAPIEngine() - } -} - -// addMCPEngine adds MCP engine if session is available. -func (c *Client) addMCPEngine() { - if c.mcpSession != nil { - eng := engine.NewMCPEngine(c.mcpSession, engine.WithMCPVerbose(c.verbose)) - c.engines = append(c.engines, eng) - } -} - -// addCLIEngine adds CLI engine if configured. 
-func (c *Client) addCLIEngine() { - if c.config.CLI != "" { - providerType := engine.CLIProviderType(c.config.CLI) - if !providerType.IsValid() { - return - } - - opts := []engine.CLIEngineOption{} - - if c.config.CLIPath != "" { - opts = append(opts, engine.WithCLIPath(c.config.CLIPath)) - } - - if c.config.Model != "" { - opts = append(opts, engine.WithCLIModel(c.config.Model)) - } - - if c.config.LargeModel != "" { - opts = append(opts, engine.WithCLILargeModel(c.config.LargeModel)) - } - - if c.verbose { - opts = append(opts, engine.WithCLIVerbose(true)) - } - - eng, err := engine.NewCLIEngine(providerType, opts...) - if err == nil && eng.IsAvailable() { - c.engines = append(c.engines, eng) - } - } -} - -// addAPIEngine adds API engine if key is available. -func (c *Client) addAPIEngine() { - apiKey := c.config.GetAPIKey() - if apiKey != "" { - eng := engine.NewAPIEngine(apiKey, engine.WithAPIVerbose(c.verbose)) - c.engines = append(c.engines, eng) - } -} - -// Request creates a new request builder. -// -// Usage: -// -// client.Request(system, user).Execute(ctx) // default complexity -// client.Request(system, user).WithComplexity(llm.ComplexityMedium).Execute(ctx) // higher complexity -// client.Request(system, user).WithComplexity(engine.ComplexityHigh).Execute(ctx) // explicit complexity -// client.Request(system, user).WithMaxTokens(2000).Execute(ctx) // custom tokens -func (c *Client) Request(systemPrompt, userPrompt string) *RequestBuilder { - return &RequestBuilder{ - client: c, - system: systemPrompt, - user: userPrompt, - maxTokens: c.maxTokens, - temperature: c.temperature, - complexity: engine.ComplexityLow, - } -} - -// GetActiveEngine returns the first available engine. -func (c *Client) GetActiveEngine() engine.LLMEngine { - for _, e := range c.engines { - if e.IsAvailable() { - return e - } - } - return nil -} - -// GetEngines returns all configured engines. 
-func (c *Client) GetEngines() []engine.LLMEngine { - return c.engines -} - -// GetConfig returns the LLM configuration. -func (c *Client) GetConfig() *LLMConfig { - return c.config -} - -// CheckAvailability checks if any LLM engine is available. -func (c *Client) CheckAvailability(ctx context.Context) error { - eng := c.GetActiveEngine() - if eng == nil { - return fmt.Errorf("no available LLM engine") - } - - // For API engine, do a simple test request - if eng.Name() == "openai-api" { - _, err := c.Request("You are a test assistant.", "Say 'OK'").Execute(ctx) - if err != nil { - return fmt.Errorf("OpenAI API not available: %w", err) - } - } - - return nil -} - -// RequestBuilder builds and executes LLM requests with chain methods. -type RequestBuilder struct { - client *Client - system string - user string - maxTokens int - temperature float64 - complexity engine.Complexity -} - -// WithComplexity sets the task complexity hint (engine-agnostic). -func (r *RequestBuilder) WithComplexity(c engine.Complexity) *RequestBuilder { - r.complexity = c - return r -} - -// WithMaxTokens sets max tokens for this request. -func (r *RequestBuilder) WithMaxTokens(tokens int) *RequestBuilder { - r.maxTokens = tokens - return r -} - -// WithTemperature sets temperature for this request. -func (r *RequestBuilder) WithTemperature(temp float64) *RequestBuilder { - r.temperature = temp - return r -} - -// Execute sends the request and returns the response. -func (r *RequestBuilder) Execute(ctx context.Context) (string, error) { - req := &engine.Request{ - SystemPrompt: r.system, - UserPrompt: r.user, - MaxTokens: r.maxTokens, - Temperature: r.temperature, - Complexity: r.complexity, - } - - return r.client.executeWithFallback(ctx, req) -} - -// executeWithFallback tries engines in priority order. 
-func (c *Client) executeWithFallback(ctx context.Context, req *engine.Request) (string, error) { - var lastErr error - - for _, eng := range c.engines { - if !eng.IsAvailable() { - continue - } - - result, err := eng.Execute(ctx, req) - if err == nil { - return result, nil - } - - lastErr = err - if c.verbose { - fmt.Fprintf(os.Stderr, "⚠️ %s failed: %v, trying next engine...\n", eng.Name(), err) - } - } - - if lastErr != nil { - return "", fmt.Errorf("all engines failed, last error: %w", lastErr) - } - - return "", fmt.Errorf("no available LLM engine configured") -} - -// ExecuteDirect executes request on a specific engine without fallback. -func (c *Client) ExecuteDirect(ctx context.Context, eng engine.LLMEngine, req *engine.Request) (string, error) { - if !eng.IsAvailable() { - return "", fmt.Errorf("engine %s is not available", eng.Name()) - } - return eng.Execute(ctx, req) -} diff --git a/internal/llm/client_test.go b/internal/llm/client_test.go deleted file mode 100644 index c462d48..0000000 --- a/internal/llm/client_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package llm - -import ( - "testing" - - "github.com/DevSymphony/sym-cli/internal/llm/engine" - "github.com/stretchr/testify/assert" -) - -func TestNewClient(t *testing.T) { - t.Run("default_config", func(t *testing.T) { - client := NewClient() - assert.NotNil(t, client) - assert.NotNil(t, client.GetConfig()) - }) - - t.Run("with_options_and_config", func(t *testing.T) { - cfg := &LLMConfig{ - Backend: engine.ModeAPI, - APIKey: "sk-test", - } - client := NewClient(WithConfig(cfg), WithVerbose(true)) - assert.NotNil(t, client) - assert.Equal(t, engine.ModeAPI, client.config.Backend) - }) - - t.Run("with_mode_option", func(t *testing.T) { - client := NewClient(WithMode(engine.ModeAPI)) - assert.NotNil(t, client) - }) -} - -func TestClient_GetActiveEngine(t *testing.T) { - t.Run("with API engine", func(t *testing.T) { - cfg := &LLMConfig{ - Backend: engine.ModeAPI, - APIKey: "sk-test", - } - client := 
NewClient(WithConfig(cfg)) - eng := client.GetActiveEngine() - assert.NotNil(t, eng) - assert.Equal(t, "openai-api", eng.Name()) - }) - - t.Run("no engine available", func(t *testing.T) { - cfg := &LLMConfig{ - Backend: engine.ModeAPI, - // No API key - } - client := NewClient(WithConfig(cfg)) - eng := client.GetActiveEngine() - assert.Nil(t, eng) - }) -} - -func TestRequestBuilder(t *testing.T) { - client := NewClient() - - t.Run("basic request", func(t *testing.T) { - builder := client.Request("system", "user") - assert.NotNil(t, builder) - }) - - t.Run("with complexity", func(t *testing.T) { - builder := client.Request("system", "user"). - WithComplexity(engine.ComplexityHigh) - assert.NotNil(t, builder) - }) - - t.Run("with max tokens", func(t *testing.T) { - builder := client.Request("system", "user"). - WithMaxTokens(2000) - assert.NotNil(t, builder) - }) - - t.Run("with temperature", func(t *testing.T) { - builder := client.Request("system", "user"). - WithTemperature(0.7) - assert.NotNil(t, builder) - }) - - t.Run("chained options", func(t *testing.T) { - builder := client.Request("system", "user"). - WithComplexity(engine.ComplexityMedium). - WithMaxTokens(1500). - WithTemperature(0.8) - assert.NotNil(t, builder) - }) -} - -func TestModeConstants(t *testing.T) { - // Verify backward compatibility - assert.Equal(t, engine.ModeAPI, ModeAPI) - assert.Equal(t, engine.ModeMCP, ModeMCP) - assert.Equal(t, engine.ModeCLI, ModeCLI) - assert.Equal(t, engine.ModeAuto, ModeAuto) -} diff --git a/internal/llm/complexity.go b/internal/llm/complexity.go deleted file mode 100644 index 243d239..0000000 --- a/internal/llm/complexity.go +++ /dev/null @@ -1,17 +0,0 @@ -package llm - -import "github.com/DevSymphony/sym-cli/internal/llm/engine" - -// Complexity re-exports engine.Complexity for backward compatibility. -type Complexity = engine.Complexity - -const ( - // ComplexityMinimal is for trivial lookups. 
- ComplexityMinimal Complexity = engine.ComplexityMinimal - // ComplexityLow is for simple transformations. - ComplexityLow Complexity = engine.ComplexityLow - // ComplexityMedium is for moderate reasoning. - ComplexityMedium Complexity = engine.ComplexityMedium - // ComplexityHigh is for complex reasoning. - ComplexityHigh Complexity = engine.ComplexityHigh -) diff --git a/internal/llm/config.go b/internal/llm/config.go index fe6eca4..1967d13 100644 --- a/internal/llm/config.go +++ b/internal/llm/config.go @@ -1,366 +1,34 @@ package llm import ( - "bufio" "fmt" - "os" - "path/filepath" - "strings" - "github.com/DevSymphony/sym-cli/internal/llm/engine" + "github.com/DevSymphony/sym-cli/internal/config" ) -const ( - // Default .sym/.env file location relative to repo root - defaultEnvFile = ".sym/.env" - - // Environment variable keys - envKeyLLMBackend = "LLM_BACKEND" - envKeyLLMCLI = "LLM_CLI" - envKeyLLMCLIPath = "LLM_CLI_PATH" - envKeyLLMModel = "LLM_MODEL" - envKeyLLMLarge = "LLM_LARGE_MODEL" - envKeyAPIKey = "OPENAI_API_KEY" -) - -// LLMConfig holds LLM engine configuration. -type LLMConfig struct { - // Backend is the preferred engine mode (auto, mcp, cli, api). - Backend engine.Mode `json:"backend"` - - // CLI is the CLI provider type (claude, gemini). - CLI string `json:"cli"` - - // CLIPath is a custom path to the CLI executable (optional). - CLIPath string `json:"cli_path"` - - // Model is the default model name for CLI engine. - Model string `json:"model"` - - // LargeModel is the model for high complexity tasks (optional). - LargeModel string `json:"large_model"` - - // APIKey is loaded from environment (not saved to config). - APIKey string `json:"-"` -} - -// DefaultLLMConfig returns the default configuration. -func DefaultLLMConfig() *LLMConfig { - return &LLMConfig{ - Backend: engine.ModeAuto, - CLI: "", - CLIPath: "", - Model: "", - } -} - -// LoadLLMConfig loads LLM configuration from .sym/.env file and environment. 
-func LoadLLMConfig() *LLMConfig { - cfg := DefaultLLMConfig() - - // Load from .sym/.env file first - envPath := defaultEnvFile - loadConfigFromEnvFile(envPath, cfg) - - // Override with system environment variables - loadConfigFromEnv(cfg) - - return cfg -} - -// LoadLLMConfigFromDir loads LLM configuration from a specific directory. -func LoadLLMConfigFromDir(dir string) *LLMConfig { - cfg := DefaultLLMConfig() - - // Load from .env file in the specified directory - envPath := filepath.Join(dir, ".env") - loadConfigFromEnvFile(envPath, cfg) - - // Override with system environment variables - loadConfigFromEnv(cfg) - - return cfg -} - -// loadConfigFromEnvFile reads config values from .env file. -func loadConfigFromEnvFile(envPath string, cfg *LLMConfig) { - file, err := os.Open(envPath) - if err != nil { - return // File doesn't exist, use defaults - } - defer func() { _ = file.Close() }() - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - - // Skip comments and empty lines - if len(line) == 0 || line[0] == '#' { - continue - } - - // Parse key=value - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - continue - } - - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - - switch key { - case envKeyLLMBackend: - if engine.Mode(value).IsValid() { - cfg.Backend = engine.Mode(value) - } - case envKeyLLMCLI: - cfg.CLI = value - case envKeyLLMCLIPath: - cfg.CLIPath = value - case envKeyLLMModel: - cfg.Model = value - case envKeyLLMLarge: - cfg.LargeModel = value - case envKeyAPIKey: - cfg.APIKey = value - } - } -} - -// loadConfigFromEnv loads config from system environment variables. 
-func loadConfigFromEnv(cfg *LLMConfig) { - if backend := os.Getenv(envKeyLLMBackend); backend != "" { - if engine.Mode(backend).IsValid() { - cfg.Backend = engine.Mode(backend) - } - } - - if cli := os.Getenv(envKeyLLMCLI); cli != "" { - cfg.CLI = cli - } - - if cliPath := os.Getenv(envKeyLLMCLIPath); cliPath != "" { - cfg.CLIPath = cliPath - } - - if model := os.Getenv(envKeyLLMModel); model != "" { - cfg.Model = model - } - - if large := os.Getenv(envKeyLLMLarge); large != "" { - cfg.LargeModel = large - } - - if apiKey := os.Getenv(envKeyAPIKey); apiKey != "" { - cfg.APIKey = apiKey - } -} - -// SaveLLMConfig saves LLM configuration to .sym/.env file. -func SaveLLMConfig(cfg *LLMConfig) error { - return SaveLLMConfigToDir(".sym", cfg) -} - -// SaveLLMConfigToDir saves LLM configuration to a specific directory. -func SaveLLMConfigToDir(dir string, cfg *LLMConfig) error { - // Ensure directory exists - if err := os.MkdirAll(dir, 0755); err != nil { - return fmt.Errorf("failed to create directory: %w", err) - } - - envPath := filepath.Join(dir, ".env") - - // Read existing content - existingLines, existingKeys := readExistingEnvFile(envPath) - - // Prepare new values - newValues := map[string]string{} - - if cfg.Backend != "" && cfg.Backend != engine.ModeAuto { - newValues[envKeyLLMBackend] = string(cfg.Backend) - } - - if cfg.CLI != "" { - newValues[envKeyLLMCLI] = cfg.CLI - } - - if cfg.CLIPath != "" { - newValues[envKeyLLMCLIPath] = cfg.CLIPath - } - - if cfg.Model != "" { - newValues[envKeyLLMModel] = cfg.Model - } - - if cfg.LargeModel != "" { - newValues[envKeyLLMLarge] = cfg.LargeModel - } - - // Build output lines - var outputLines []string - - // Update existing lines - for _, line := range existingLines { - trimmed := strings.TrimSpace(line) - - // Keep comments and empty lines - if trimmed == "" || strings.HasPrefix(trimmed, "#") { - outputLines = append(outputLines, line) - continue - } - - // Parse key - parts := strings.SplitN(trimmed, "=", 2) - if 
len(parts) != 2 { - outputLines = append(outputLines, line) - continue - } - - key := strings.TrimSpace(parts[0]) - - // Check if we have a new value for this key - if newValue, ok := newValues[key]; ok { - outputLines = append(outputLines, fmt.Sprintf("%s=%s", key, newValue)) - delete(newValues, key) // Mark as processed - } else { - outputLines = append(outputLines, line) - } - } - - // Add LLM config section header if needed - hasLLMSection := false - for key := range existingKeys { - if strings.HasPrefix(key, "LLM_") { - hasLLMSection = true - break - } - } - - // Add new keys that weren't in the file - if len(newValues) > 0 { - if !hasLLMSection { - outputLines = append(outputLines, "", "# LLM Backend Configuration") - } - - for key, value := range newValues { - outputLines = append(outputLines, fmt.Sprintf("%s=%s", key, value)) - } - } - - // Write to file - content := strings.Join(outputLines, "\n") - if !strings.HasSuffix(content, "\n") { - content += "\n" - } - - return os.WriteFile(envPath, []byte(content), 0600) -} - -// readExistingEnvFile reads existing .env file content. -func readExistingEnvFile(envPath string) ([]string, map[string]bool) { - var lines []string - keys := make(map[string]bool) - - file, err := os.Open(envPath) - if err != nil { - return lines, keys - } - defer func() { _ = file.Close() }() - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - lines = append(lines, line) - - // Track existing keys - trimmed := strings.TrimSpace(line) - if len(trimmed) > 0 && !strings.HasPrefix(trimmed, "#") { - parts := strings.SplitN(trimmed, "=", 2) - if len(parts) == 2 { - keys[strings.TrimSpace(parts[0])] = true - } - } - } - - return lines, keys -} - -// GetAPIKey returns the API key from config or environment. -func (c *LLMConfig) GetAPIKey() string { - if c.APIKey != "" { - return c.APIKey - } - return os.Getenv(envKeyAPIKey) -} - -// HasCLI returns true if CLI is configured. 
-func (c *LLMConfig) HasCLI() bool { - return c.CLI != "" -} - -// HasAPIKey returns true if API key is available. -func (c *LLMConfig) HasAPIKey() bool { - return c.GetAPIKey() != "" -} - -// GetEffectiveBackend returns the actual engine to use based on availability. -func (c *LLMConfig) GetEffectiveBackend() engine.Mode { - if c.Backend != engine.ModeAuto { - return c.Backend - } - - // Auto mode: prefer CLI if available, then API - if c.HasCLI() { - return engine.ModeCLI - } - - if c.HasAPIKey() { - return engine.ModeAPI - } - - return engine.ModeAuto -} - // Validate checks if the configuration is valid. -func (c *LLMConfig) Validate() error { - if c.Backend != "" && !c.Backend.IsValid() { - return fmt.Errorf("invalid engine mode: %s", c.Backend) - } - - if c.CLI != "" && !engine.CLIProviderType(c.CLI).IsValid() { - return fmt.Errorf("unsupported CLI provider: %s", c.CLI) +func (c *Config) Validate() error { + if c.Provider == "" { + return fmt.Errorf("provider is required (configure in .sym/config.json)") } - return nil } -// String returns a human-readable representation of the config. -func (c *LLMConfig) String() string { - var parts []string - - parts = append(parts, fmt.Sprintf("Backend: %s", c.Backend)) - - if c.CLI != "" { - parts = append(parts, fmt.Sprintf("CLI: %s", c.CLI)) - } - - if c.CLIPath != "" { - parts = append(parts, fmt.Sprintf("CLI Path: %s", c.CLIPath)) - } - - if c.Model != "" { - parts = append(parts, fmt.Sprintf("Model: %s", c.Model)) - } +// LoadConfig loads configuration from .sym/config.json. +func LoadConfig() Config { + return LoadConfigFromDir("") +} - if c.LargeModel != "" { - parts = append(parts, fmt.Sprintf("Large Model: %s", c.LargeModel)) - } +// LoadConfigFromDir loads configuration from .sym/config.json. +// Note: API keys are handled by individual providers (e.g., openaiapi uses envutil.GetAPIKey). 
+func LoadConfigFromDir(_ string) Config { + cfg := Config{} - if c.HasAPIKey() { - parts = append(parts, "API Key: configured") - } else { - parts = append(parts, "API Key: not set") + // Load from .sym/config.json + if projectCfg, err := config.LoadProjectConfig(); err == nil { + cfg.Provider = projectCfg.LLM.Provider + cfg.Model = projectCfg.LLM.Model } - return strings.Join(parts, ", ") + return cfg } diff --git a/internal/llm/config_test.go b/internal/llm/config_test.go deleted file mode 100644 index b5bc4e1..0000000 --- a/internal/llm/config_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package llm - -import ( - "os" - "path/filepath" - "testing" - - "github.com/DevSymphony/sym-cli/internal/llm/engine" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestDefaultLLMConfig(t *testing.T) { - cfg := DefaultLLMConfig() - - assert.Equal(t, engine.ModeAuto, cfg.Backend) - assert.Empty(t, cfg.CLI) - assert.Empty(t, cfg.CLIPath) - assert.Empty(t, cfg.Model) -} - -func TestLLMConfig_HasCLI(t *testing.T) { - t.Run("with CLI", func(t *testing.T) { - cfg := &LLMConfig{CLI: "claude"} - assert.True(t, cfg.HasCLI()) - }) - - t.Run("without CLI", func(t *testing.T) { - cfg := &LLMConfig{} - assert.False(t, cfg.HasCLI()) - }) -} - -func TestLLMConfig_HasAPIKey(t *testing.T) { - t.Run("with API key in config", func(t *testing.T) { - cfg := &LLMConfig{APIKey: "sk-test"} - assert.True(t, cfg.HasAPIKey()) - }) - - t.Run("without API key", func(t *testing.T) { - cfg := &LLMConfig{} - assert.False(t, cfg.HasAPIKey()) - }) -} - -func TestLLMConfig_GetEffectiveBackend(t *testing.T) { - t.Run("explicit mode", func(t *testing.T) { - cfg := &LLMConfig{Backend: engine.ModeCLI} - assert.Equal(t, engine.ModeCLI, cfg.GetEffectiveBackend()) - }) - - t.Run("auto with CLI", func(t *testing.T) { - cfg := &LLMConfig{Backend: engine.ModeAuto, CLI: "claude"} - assert.Equal(t, engine.ModeCLI, cfg.GetEffectiveBackend()) - }) - - t.Run("auto with API key", func(t 
*testing.T) { - cfg := &LLMConfig{Backend: engine.ModeAuto, APIKey: "sk-test"} - assert.Equal(t, engine.ModeAPI, cfg.GetEffectiveBackend()) - }) - - t.Run("auto with nothing", func(t *testing.T) { - cfg := &LLMConfig{Backend: engine.ModeAuto} - assert.Equal(t, engine.ModeAuto, cfg.GetEffectiveBackend()) - }) -} - -func TestLLMConfig_Validate(t *testing.T) { - t.Run("valid config", func(t *testing.T) { - cfg := &LLMConfig{ - Backend: engine.ModeAuto, - CLI: "claude", - } - assert.NoError(t, cfg.Validate()) - }) - - t.Run("invalid backend", func(t *testing.T) { - cfg := &LLMConfig{Backend: engine.Mode("invalid")} - assert.Error(t, cfg.Validate()) - }) - - t.Run("invalid CLI provider", func(t *testing.T) { - cfg := &LLMConfig{CLI: "invalid-cli"} - assert.Error(t, cfg.Validate()) - }) - - t.Run("empty config is valid", func(t *testing.T) { - cfg := &LLMConfig{} - assert.NoError(t, cfg.Validate()) - }) -} - -func TestLLMConfig_String(t *testing.T) { - cfg := &LLMConfig{ - Backend: engine.ModeAuto, - CLI: "claude", - Model: "claude-3-opus", - } - - str := cfg.String() - assert.Contains(t, str, "Backend: auto") - assert.Contains(t, str, "CLI: claude") - assert.Contains(t, str, "Model: claude-3-opus") -} - -func TestSaveLLMConfig(t *testing.T) { - tmpDir := t.TempDir() - - cfg := &LLMConfig{ - Backend: engine.ModeCLI, - CLI: "claude", - Model: "claude-3-opus", - LargeModel: "claude-3-opus", - } - - err := SaveLLMConfigToDir(tmpDir, cfg) - require.NoError(t, err) - - // Verify file was created - envPath := filepath.Join(tmpDir, ".env") - _, err = os.Stat(envPath) - require.NoError(t, err) - - // Read and verify content - content, err := os.ReadFile(envPath) - require.NoError(t, err) - - assert.Contains(t, string(content), "LLM_BACKEND=cli") - assert.Contains(t, string(content), "LLM_CLI=claude") - assert.Contains(t, string(content), "LLM_MODEL=claude-3-opus") -} - -func TestLoadLLMConfigFromDir(t *testing.T) { - tmpDir := t.TempDir() - - // Create .env file - envContent := 
`# Test config -LLM_BACKEND=cli -LLM_CLI=gemini -LLM_MODEL=gemini-pro -` - envPath := filepath.Join(tmpDir, ".env") - err := os.WriteFile(envPath, []byte(envContent), 0600) - require.NoError(t, err) - - cfg := LoadLLMConfigFromDir(tmpDir) - - assert.Equal(t, engine.ModeCLI, cfg.Backend) - assert.Equal(t, "gemini", cfg.CLI) - assert.Equal(t, "gemini-pro", cfg.Model) -} - -func TestLoadLLMConfigFromDir_NonExistent(t *testing.T) { - cfg := LoadLLMConfigFromDir("/nonexistent/path") - - // Should return defaults - assert.Equal(t, engine.ModeAuto, cfg.Backend) - assert.Empty(t, cfg.CLI) -} - diff --git a/internal/llm/engine/api.go b/internal/llm/engine/api.go deleted file mode 100644 index 6c4227f..0000000 --- a/internal/llm/engine/api.go +++ /dev/null @@ -1,247 +0,0 @@ -package engine - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "time" -) - -const ( - openAIAPIURL = "https://api.openai.com/v1/chat/completions" - defaultAPIFastModel = "gpt-4o-mini" - defaultAPIPowerModel = "gpt-5-mini" - defaultAPITimeout = 60 * time.Second - defaultAPIMaxTokens = 1000 - defaultAPITemperature = 1.0 -) - -// APIEngine implements LLMEngine interface for OpenAI API. -type APIEngine struct { - apiKey string - fastModel string - powerModel string - httpClient *http.Client - maxTokens int - temperature float64 - verbose bool -} - -// APIEngineOption is a functional option for APIEngine. -type APIEngineOption func(*APIEngine) - -// WithAPIFastModel sets the fast model. -func WithAPIFastModel(model string) APIEngineOption { - return func(e *APIEngine) { e.fastModel = model } -} - -// WithAPIPowerModel sets the power model. -func WithAPIPowerModel(model string) APIEngineOption { - return func(e *APIEngine) { e.powerModel = model } -} - -// WithAPITimeout sets the HTTP client timeout. 
-func WithAPITimeout(timeout time.Duration) APIEngineOption { - return func(e *APIEngine) { e.httpClient.Timeout = timeout } -} - -// WithAPIVerbose enables verbose logging. -func WithAPIVerbose(verbose bool) APIEngineOption { - return func(e *APIEngine) { e.verbose = verbose } -} - -// NewAPIEngine creates a new OpenAI API engine. -func NewAPIEngine(apiKey string, opts ...APIEngineOption) *APIEngine { - e := &APIEngine{ - apiKey: apiKey, - fastModel: defaultAPIFastModel, - powerModel: defaultAPIPowerModel, - httpClient: &http.Client{Timeout: defaultAPITimeout}, - maxTokens: defaultAPIMaxTokens, - temperature: defaultAPITemperature, - verbose: false, - } - - for _, opt := range opts { - opt(e) - } - - return e -} - -// Name returns the engine identifier. -func (e *APIEngine) Name() string { - return "openai-api" -} - -// IsAvailable checks if the engine can be used. -func (e *APIEngine) IsAvailable() bool { - return e.apiKey != "" -} - -// Capabilities returns engine capabilities. -func (e *APIEngine) Capabilities() Capabilities { - return Capabilities{ - SupportsTemperature: true, - SupportsMaxTokens: true, - SupportsComplexity: true, - SupportsStreaming: true, - MaxContextLength: 128000, - Models: []string{e.fastModel, e.powerModel}, - } -} - -// Execute sends the request via OpenAI API. 
-func (e *APIEngine) Execute(ctx context.Context, req *Request) (string, error) { - if e.apiKey == "" { - return "", fmt.Errorf("OpenAI API key not configured") - } - - // Select model based on complexity - model := e.fastModel - var reasoningEffort string - - switch req.Complexity { - case ComplexityMinimal: - model = e.fastModel - reasoningEffort = "minimal" - case ComplexityLow: - model = e.fastModel - case ComplexityMedium: - model = e.powerModel - reasoningEffort = "low" - case ComplexityHigh: - model = e.powerModel - reasoningEffort = "medium" - } - - // Build request body - maxTokens := req.MaxTokens - if maxTokens == 0 { - maxTokens = e.maxTokens - } - - temperature := req.Temperature - if temperature == 0 { - temperature = e.temperature - } - - apiReq := openAIAPIRequest{ - Model: model, - Messages: []openAIAPIMessage{ - {Role: "user", Content: req.CombinedPrompt()}, - }, - MaxTokens: maxTokens, - Temperature: temperature, - } - - if reasoningEffort != "" { - apiReq.ReasoningEffort = reasoningEffort - } - - jsonData, err := json.Marshal(apiReq) - if err != nil { - return "", fmt.Errorf("failed to marshal request: %w", err) - } - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, openAIAPIURL, bytes.NewBuffer(jsonData)) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+e.apiKey) - - if e.verbose { - fmt.Fprintf(os.Stderr, "OpenAI API request:\n Model: %s\n Complexity: %s\n Prompt length: %d chars\n", - model, req.Complexity, len(req.UserPrompt)) - } - - resp, err := e.httpClient.Do(httpReq) - if err != nil { - return "", fmt.Errorf("failed to send request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return "", 
fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(body)) - } - - var apiResp openAIAPIResponse - if err := json.Unmarshal(body, &apiResp); err != nil { - return "", fmt.Errorf("failed to unmarshal response: %w", err) - } - - if apiResp.Error != nil { - return "", fmt.Errorf("OpenAI API error: %s (type: %s, code: %s)", - apiResp.Error.Message, apiResp.Error.Type, apiResp.Error.Code) - } - - if len(apiResp.Choices) == 0 { - return "", fmt.Errorf("no choices in response") - } - - content := apiResp.Choices[0].Message.Content - - if e.verbose { - fmt.Fprintf(os.Stderr, "OpenAI API response:\n Tokens: %d\n Content length: %d chars\n", - apiResp.Usage.TotalTokens, len(content)) - } - - return content, nil -} - -// SetVerbose sets verbose mode. -func (e *APIEngine) SetVerbose(verbose bool) { - e.verbose = verbose -} - -// openAIAPIRequest represents the OpenAI API request structure. -type openAIAPIRequest struct { - Model string `json:"model"` - Messages []openAIAPIMessage `json:"messages"` - MaxTokens int `json:"max_completion_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - ReasoningEffort string `json:"reasoning_effort,omitempty"` -} - -// openAIAPIMessage represents a message in the OpenAI API request. -type openAIAPIMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -// openAIAPIResponse represents the OpenAI API response structure. 
-type openAIAPIResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` - Error *struct { - Message string `json:"message"` - Type string `json:"type"` - Code string `json:"code"` - } `json:"error,omitempty"` -} diff --git a/internal/llm/engine/api_test.go b/internal/llm/engine/api_test.go deleted file mode 100644 index 121df05..0000000 --- a/internal/llm/engine/api_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package engine - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestNewAPIEngine(t *testing.T) { - t.Run("with api key", func(t *testing.T) { - engine := NewAPIEngine("sk-test-key") - assert.NotNil(t, engine) - assert.Equal(t, "openai-api", engine.Name()) - assert.True(t, engine.IsAvailable()) - }) - - t.Run("without api key", func(t *testing.T) { - engine := NewAPIEngine("") - assert.NotNil(t, engine) - assert.False(t, engine.IsAvailable()) - }) - - t.Run("with options", func(t *testing.T) { - engine := NewAPIEngine("sk-test-key", - WithAPIFastModel("gpt-4o"), - WithAPIPowerModel("o3-mini"), - WithAPIVerbose(true), - ) - assert.NotNil(t, engine) - caps := engine.Capabilities() - assert.Contains(t, caps.Models, "gpt-4o") - assert.Contains(t, caps.Models, "o3-mini") - }) -} - -func TestAPIEngine_Capabilities(t *testing.T) { - engine := NewAPIEngine("sk-test-key") - caps := engine.Capabilities() - - assert.True(t, caps.SupportsTemperature) - assert.True(t, caps.SupportsMaxTokens) - assert.True(t, caps.SupportsComplexity) - assert.True(t, caps.SupportsStreaming) - assert.Equal(t, 
128000, caps.MaxContextLength) - assert.Len(t, caps.Models, 2) -} - -func TestAPIEngine_Name(t *testing.T) { - engine := NewAPIEngine("sk-test-key") - assert.Equal(t, "openai-api", engine.Name()) -} - -func TestAPIEngine_IsAvailable(t *testing.T) { - t.Run("available with key", func(t *testing.T) { - engine := NewAPIEngine("sk-test-key") - assert.True(t, engine.IsAvailable()) - }) - - t.Run("not available without key", func(t *testing.T) { - engine := NewAPIEngine("") - assert.False(t, engine.IsAvailable()) - }) -} - diff --git a/internal/llm/engine/cli.go b/internal/llm/engine/cli.go deleted file mode 100644 index 1d605bc..0000000 --- a/internal/llm/engine/cli.go +++ /dev/null @@ -1,224 +0,0 @@ -package engine - -import ( - "bytes" - "context" - "fmt" - "os" - "os/exec" - "time" - - "github.com/DevSymphony/sym-cli/internal/llm/engine/cliprovider" -) - -const ( - defaultCLITimeout = 120 * time.Second -) - -// Re-export CLI provider types for backward compatibility. -type CLIProviderType = cliprovider.Type - -const ( - // ProviderClaude is the Claude CLI provider. - ProviderClaude CLIProviderType = cliprovider.TypeClaude - // ProviderGemini is the Gemini CLI provider. - ProviderGemini CLIProviderType = cliprovider.TypeGemini -) - -// CLIProvider is an alias to cliprovider.Provider. -type CLIProvider = cliprovider.Provider - -// CLIInfo is an alias to cliprovider.Info. -type CLIInfo = cliprovider.Info - -// SupportedProviders returns all supported CLI providers. -func SupportedProviders() map[CLIProviderType]*CLIProvider { - return cliprovider.Supported() -} - -// GetProvider returns the provider for the given type. -func GetProvider(providerType CLIProviderType) (*CLIProvider, error) { - return cliprovider.Get(providerType) -} - -// DetectAvailableCLIs scans for installed CLI tools. -func DetectAvailableCLIs() []CLIInfo { - return cliprovider.Detect() -} - -// GetProviderByCommand finds a provider by its command name. 
-func GetProviderByCommand(command string) (*CLIProvider, error) { - return cliprovider.GetByCommand(command) -} - -// CLIEngine implements LLMEngine interface for CLI-based LLM tools. -type CLIEngine struct { - provider *cliprovider.Provider - model string - largeModel string - timeout time.Duration - verbose bool - customPath string -} - -// CLIEngineOption is a functional option for CLIEngine. -type CLIEngineOption func(*CLIEngine) - -// WithCLIModel sets the default model. -func WithCLIModel(model string) CLIEngineOption { - return func(e *CLIEngine) { e.model = model } -} - -// WithCLILargeModel sets the model for high complexity tasks. -func WithCLILargeModel(model string) CLIEngineOption { - return func(e *CLIEngine) { e.largeModel = model } -} - -// WithCLITimeout sets the execution timeout. -func WithCLITimeout(timeout time.Duration) CLIEngineOption { - return func(e *CLIEngine) { e.timeout = timeout } -} - -// WithCLIVerbose enables verbose logging. -func WithCLIVerbose(verbose bool) CLIEngineOption { - return func(e *CLIEngine) { e.verbose = verbose } -} - -// WithCLIPath sets a custom path to the CLI executable. -func WithCLIPath(path string) CLIEngineOption { - return func(e *CLIEngine) { e.customPath = path } -} - -// NewCLIEngine creates a new CLI engine for the given provider. -func NewCLIEngine(providerType CLIProviderType, opts ...CLIEngineOption) (*CLIEngine, error) { - provider, err := cliprovider.Get(providerType) - if err != nil { - return nil, err - } - - e := &CLIEngine{ - provider: provider, - model: provider.DefaultModel, - largeModel: provider.LargeModel, - timeout: defaultCLITimeout, - verbose: false, - } - - for _, opt := range opts { - opt(e) - } - - return e, nil -} - -// Name returns the engine identifier. -func (e *CLIEngine) Name() string { - return fmt.Sprintf("cli-%s", e.provider.Type) -} - -// IsAvailable checks if the engine can be used. 
-func (e *CLIEngine) IsAvailable() bool { - cmdPath := e.getCommandPath() - _, err := exec.LookPath(cmdPath) - return err == nil -} - -// Capabilities returns engine capabilities. -func (e *CLIEngine) Capabilities() Capabilities { - models := []string{e.model} - if e.largeModel != "" && e.largeModel != e.model { - models = append(models, e.largeModel) - } - - return Capabilities{ - SupportsTemperature: e.provider.SupportsTemperature, - SupportsMaxTokens: e.provider.SupportsMaxTokens, - SupportsComplexity: e.largeModel != "", - SupportsStreaming: false, - MaxContextLength: 0, - Models: models, - } -} - -// Execute sends the request via CLI. -func (e *CLIEngine) Execute(ctx context.Context, req *Request) (string, error) { - model := e.model - if req.Complexity >= ComplexityHigh && e.largeModel != "" { - model = e.largeModel - } - - prompt := req.CombinedPrompt() - args := e.provider.BuildArgs(model, prompt) - args = e.appendOptionalFlags(args, req) - - if e.verbose { - fmt.Fprintf(os.Stderr, "CLI Engine (%s) request:\n Model: %s\n Complexity: %s\n Prompt length: %d chars\n", - e.provider.Type, model, req.Complexity, len(prompt)) - } - - cmdCtx, cancel := context.WithTimeout(ctx, e.timeout) - defer cancel() - - cmdPath := e.getCommandPath() - cmd := exec.CommandContext(cmdCtx, cmdPath, args...) 
- - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - err := cmd.Run() - if err != nil { - if cmdCtx.Err() == context.DeadlineExceeded { - return "", fmt.Errorf("CLI command timed out after %v", e.timeout) - } - return "", fmt.Errorf("CLI command failed: %w\nstdout: %s\nstderr: %s", err, stdout.String(), stderr.String()) - } - - response, err := e.provider.ParseResponse(stdout.Bytes()) - if err != nil { - return "", fmt.Errorf("failed to parse CLI response: %w", err) - } - - if e.verbose { - fmt.Fprintf(os.Stderr, "CLI Engine (%s) response:\n Content length: %d chars\n", - e.provider.Type, len(response)) - } - - return response, nil -} - -// getCommandPath returns the path to the CLI executable. -func (e *CLIEngine) getCommandPath() string { - if e.customPath != "" { - return e.customPath - } - return e.provider.Command -} - -// appendOptionalFlags adds optional flags based on request parameters. -func (e *CLIEngine) appendOptionalFlags(args []string, req *Request) []string { - if e.provider.SupportsMaxTokens && e.provider.MaxTokensFlag != "" && req.MaxTokens > 0 { - args = append(args, e.provider.MaxTokensFlag, fmt.Sprintf("%d", req.MaxTokens)) - } - - if e.provider.SupportsTemperature && e.provider.TemperatureFlag != "" && req.Temperature > 0 { - args = append(args, e.provider.TemperatureFlag, fmt.Sprintf("%.2f", req.Temperature)) - } - - return args -} - -// GetProvider returns the underlying provider. -func (e *CLIEngine) GetProvider() *CLIProvider { - return e.provider -} - -// GetModel returns the current model. -func (e *CLIEngine) GetModel() string { - return e.model -} - -// SetVerbose sets verbose mode. 
-func (e *CLIEngine) SetVerbose(verbose bool) { - e.verbose = verbose -} diff --git a/internal/llm/engine/cli_test.go b/internal/llm/engine/cli_test.go deleted file mode 100644 index bfe3dfe..0000000 --- a/internal/llm/engine/cli_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package engine - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewCLIEngine(t *testing.T) { - t.Run("valid provider", func(t *testing.T) { - engine, err := NewCLIEngine(ProviderClaude) - require.NoError(t, err) - assert.NotNil(t, engine) - assert.Equal(t, "cli-claude", engine.Name()) - }) - - t.Run("with options", func(t *testing.T) { - engine, err := NewCLIEngine( - ProviderClaude, - WithCLIModel("custom-model"), - WithCLILargeModel("large-model"), - WithCLIVerbose(true), - ) - require.NoError(t, err) - assert.Equal(t, "custom-model", engine.GetModel()) - }) - - t.Run("invalid provider", func(t *testing.T) { - _, err := NewCLIEngine(CLIProviderType("invalid")) - assert.Error(t, err) - }) -} - -func TestCLIEngine_Capabilities(t *testing.T) { - engine, err := NewCLIEngine(ProviderClaude) - require.NoError(t, err) - - caps := engine.Capabilities() - - assert.False(t, caps.SupportsMaxTokens) - assert.False(t, caps.SupportsStreaming) - assert.True(t, caps.SupportsComplexity) // Has LargeModel - assert.NotEmpty(t, caps.Models) -} - -func TestDetectAvailableCLIs(t *testing.T) { - clis := DetectAvailableCLIs() - - // Should return info for all supported providers - assert.Len(t, clis, 2) - - // Each CLI should have provider and name set - for _, cli := range clis { - assert.NotEmpty(t, cli.Provider) - assert.NotEmpty(t, cli.Name) - } -} - -func TestGetProviderByCommand(t *testing.T) { - t.Run("claude command", func(t *testing.T) { - provider, err := GetProviderByCommand("claude") - require.NoError(t, err) - assert.Equal(t, ProviderClaude, provider.Type) - }) - - t.Run("gemini command", func(t *testing.T) { - provider, err := 
GetProviderByCommand("gemini") - require.NoError(t, err) - assert.Equal(t, ProviderGemini, provider.Type) - }) - - t.Run("unknown command", func(t *testing.T) { - _, err := GetProviderByCommand("unknown") - assert.Error(t, err) - }) -} diff --git a/internal/llm/engine/cliprovider/claude.go b/internal/llm/engine/cliprovider/claude.go deleted file mode 100644 index 4961c98..0000000 --- a/internal/llm/engine/cliprovider/claude.go +++ /dev/null @@ -1,30 +0,0 @@ -package cliprovider - -import "strings" - -func newClaudeProvider() *Provider { - return &Provider{ - Type: TypeClaude, - DisplayName: "Claude CLI", - Command: "claude", - DefaultModel: "claude-haiku-4-5-20251001", - LargeModel: "claude-sonnet-4-5-20250929", - BuildArgs: func(model string, prompt string) []string { - args := []string{ - "-p", prompt, - "--output-format", "text", - } - if model != "" { - args = append(args, "--model", model) - } - return args - }, - ParseResponse: func(output []byte) (string, error) { - return strings.TrimSpace(string(output)), nil - }, - SupportsMaxTokens: false, - MaxTokensFlag: "", - SupportsTemperature: false, - TemperatureFlag: "", - } -} diff --git a/internal/llm/engine/cliprovider/gemini.go b/internal/llm/engine/cliprovider/gemini.go deleted file mode 100644 index f299a8e..0000000 --- a/internal/llm/engine/cliprovider/gemini.go +++ /dev/null @@ -1,27 +0,0 @@ -package cliprovider - -import "strings" - -func newGeminiProvider() *Provider { - return &Provider{ - Type: TypeGemini, - DisplayName: "Gemini CLI", - Command: "gemini", - DefaultModel: "gemini-2.0-flash", - LargeModel: "gemini-2.5-pro-preview-06-05", - BuildArgs: func(model string, prompt string) []string { - return []string{ - "prompt", - "-m", model, - prompt, - } - }, - ParseResponse: func(output []byte) (string, error) { - return strings.TrimSpace(string(output)), nil - }, - SupportsMaxTokens: true, - MaxTokensFlag: "--max-tokens", - SupportsTemperature: true, - TemperatureFlag: "--temperature", - } -} diff 
--git a/internal/llm/engine/cliprovider/provider.go b/internal/llm/engine/cliprovider/provider.go deleted file mode 100644 index 7922d3f..0000000 --- a/internal/llm/engine/cliprovider/provider.go +++ /dev/null @@ -1,141 +0,0 @@ -package cliprovider - -import ( - "fmt" - "os/exec" - "strings" -) - -// Type represents supported CLI provider types. -type Type string - -const ( - // TypeClaude is the Claude CLI provider. - TypeClaude Type = "claude" - // TypeGemini is the Gemini CLI provider. - TypeGemini Type = "gemini" -) - -// IsValid checks if the provider type is valid. -func (t Type) IsValid() bool { - switch t { - case TypeClaude, TypeGemini: - return true - default: - return false - } -} - -// Provider defines how to interact with a specific CLI tool. -type Provider struct { - // Type is the provider identifier. - Type Type - - // DisplayName is the human-readable name. - DisplayName string - - // Command is the executable name or path. - Command string - - // DefaultModel is the default model to use. - DefaultModel string - - // LargeModel is the model for high complexity tasks (optional). - LargeModel string - - // BuildArgs constructs CLI arguments for the given request. - BuildArgs func(model string, prompt string) []string - - // ParseResponse extracts text from CLI output. - ParseResponse func(output []byte) (string, error) - - // SupportsMaxTokens indicates if --max-tokens or similar is supported. - SupportsMaxTokens bool - - // MaxTokensFlag is the flag name for max tokens (e.g., "--max-tokens"). - MaxTokensFlag string - - // SupportsTemperature indicates if temperature is supported. - SupportsTemperature bool - - // TemperatureFlag is the flag name for temperature. - TemperatureFlag string -} - -// Info represents detected CLI information. -type Info struct { - Provider Type - Name string - Path string - Version string - Available bool -} - -// Supported returns all supported CLI providers. 
-func Supported() map[Type]*Provider { - return map[Type]*Provider{ - TypeClaude: newClaudeProvider(), - TypeGemini: newGeminiProvider(), - } -} - -// Get returns the provider for the given type. -func Get(providerType Type) (*Provider, error) { - providers := Supported() - provider, ok := providers[providerType] - if !ok { - return nil, fmt.Errorf("unsupported CLI provider: %s", providerType) - } - return provider, nil -} - -// Detect scans for installed CLI tools. -func Detect() []Info { - var results []Info - - providers := Supported() - for providerType, provider := range providers { - info := Info{ - Provider: providerType, - Name: provider.DisplayName, - Available: false, - } - - path, err := exec.LookPath(provider.Command) - if err == nil { - info.Path = path - info.Available = true - info.Version = getProviderVersion(provider) - } - - results = append(results, info) - } - - return results -} - -// GetByCommand finds a provider by its command name. -func GetByCommand(command string) (*Provider, error) { - providers := Supported() - for _, provider := range providers { - if provider.Command == command { - return provider, nil - } - } - return nil, fmt.Errorf("no provider found for command: %s", command) -} - -func getProviderVersion(provider *Provider) string { - cmd := exec.Command(provider.Command, "--version") // #nosec G204 - output, err := cmd.Output() - if err != nil { - return "" - } - - lines := strings.Split(strings.TrimSpace(string(output)), "\n") - if len(lines) > 0 { - return strings.TrimSpace(lines[0]) - } - - return "" -} diff --git a/internal/llm/engine/cliprovider/provider_test.go b/internal/llm/engine/cliprovider/provider_test.go deleted file mode 100644 index 48e637e..0000000 --- a/internal/llm/engine/cliprovider/provider_test.go +++ /dev/null @@ -1,120 +0,0 @@ -package cliprovider - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestType_IsValid(t *testing.T) { - tests := 
[]struct { - name string - typ Type - want bool - }{ - {"claude", TypeClaude, true}, - {"gemini", TypeGemini, true}, - {"invalid", Type("invalid"), false}, - {"empty", Type(""), false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.want, tt.typ.IsValid()) - }) - } -} - -func TestSupported(t *testing.T) { - providers := Supported() - - assert.Len(t, providers, 2) - assert.Contains(t, providers, TypeClaude) - assert.Contains(t, providers, TypeGemini) -} - -func TestGet(t *testing.T) { - t.Run("claude", func(t *testing.T) { - provider, err := Get(TypeClaude) - require.NoError(t, err) - assert.Equal(t, TypeClaude, provider.Type) - assert.Equal(t, "Claude CLI", provider.DisplayName) - assert.Equal(t, "claude", provider.Command) - }) - - t.Run("gemini", func(t *testing.T) { - provider, err := Get(TypeGemini) - require.NoError(t, err) - assert.Equal(t, TypeGemini, provider.Type) - assert.Equal(t, "Gemini CLI", provider.DisplayName) - assert.Equal(t, "gemini", provider.Command) - }) - - t.Run("invalid", func(t *testing.T) { - _, err := Get(Type("invalid")) - assert.Error(t, err) - }) -} - -func TestBuildArgs(t *testing.T) { - t.Run("claude", func(t *testing.T) { - provider := newClaudeProvider() - args := provider.BuildArgs("claude-3-opus", "Hello!") - - assert.Contains(t, args, "-p") - assert.Contains(t, args, "Hello!") - assert.Contains(t, args, "--model") - assert.Contains(t, args, "claude-3-opus") - }) - - t.Run("gemini", func(t *testing.T) { - provider := newGeminiProvider() - args := provider.BuildArgs("gemini-pro", "Hello!") - - assert.Contains(t, args, "prompt") - assert.Contains(t, args, "-m") - assert.Contains(t, args, "gemini-pro") - }) -} - -func TestParseResponse(t *testing.T) { - providers := Supported() - - for typ, provider := range providers { - t.Run(string(typ), func(t *testing.T) { - resp, err := provider.ParseResponse([]byte(" trimmed response \n")) - require.NoError(t, err) - assert.Equal(t, "trimmed 
response", resp) - }) - } -} - -func TestDetect(t *testing.T) { - info := Detect() - assert.Len(t, info, 2) - - for _, cli := range info { - assert.NotEmpty(t, cli.Provider) - assert.NotEmpty(t, cli.Name) - } -} - -func TestGetByCommand(t *testing.T) { - t.Run("claude", func(t *testing.T) { - provider, err := GetByCommand("claude") - require.NoError(t, err) - assert.Equal(t, TypeClaude, provider.Type) - }) - - t.Run("gemini", func(t *testing.T) { - provider, err := GetByCommand("gemini") - require.NoError(t, err) - assert.Equal(t, TypeGemini, provider.Type) - }) - - t.Run("invalid", func(t *testing.T) { - _, err := GetByCommand("unknown") - assert.Error(t, err) - }) -} diff --git a/internal/llm/engine/engine.go b/internal/llm/engine/engine.go deleted file mode 100644 index 99756cc..0000000 --- a/internal/llm/engine/engine.go +++ /dev/null @@ -1,113 +0,0 @@ -package engine - -import "context" - -// Complexity represents task complexity hint (engine-agnostic). -// This allows callers to express intent without coupling to specific engine features. -type Complexity int - -const ( - // ComplexityMinimal is for trivial lookups or boilerplate prompts. - ComplexityMinimal Complexity = iota - // ComplexityLow is for simple transformations, parsing, basic formatting. - ComplexityLow - // ComplexityMedium is for analysis, routing decisions, moderate reasoning. - ComplexityMedium - // ComplexityHigh is for complex reasoning, code generation, deep analysis. - ComplexityHigh -) - -// String returns human-readable complexity name. -func (c Complexity) String() string { - switch c { - case ComplexityMinimal: - return "minimal" - case ComplexityLow: - return "low" - case ComplexityMedium: - return "medium" - case ComplexityHigh: - return "high" - default: - return "unknown" - } -} - -// Request represents an engine-agnostic LLM request. -// All engines receive this unified request format and interpret it according to their capabilities. 
-type Request struct { - SystemPrompt string - UserPrompt string - MaxTokens int - Temperature float64 - Complexity Complexity -} - -// CombinedPrompt returns system and user prompts combined. -func (r *Request) CombinedPrompt() string { - if r.SystemPrompt == "" { - return r.UserPrompt - } - return r.SystemPrompt + "\n\n" + r.UserPrompt -} - -// LLMEngine is the interface for LLM execution engines. -type LLMEngine interface { - // Execute sends request and returns response text. - Execute(ctx context.Context, req *Request) (string, error) - - // Name returns engine identifier. - Name() string - - // IsAvailable checks if this engine can currently be used. - IsAvailable() bool - - // Capabilities returns what features this engine supports. - Capabilities() Capabilities -} - -// Capabilities describes what features an engine supports. -// This enables graceful degradation when features aren't available. -type Capabilities struct { - // SupportsTemperature indicates if temperature parameter is respected. - SupportsTemperature bool - - // SupportsMaxTokens indicates if max_tokens parameter is respected. - SupportsMaxTokens bool - - // SupportsComplexity indicates if complexity hint affects model selection. - SupportsComplexity bool - - // SupportsStreaming indicates if streaming responses are supported. - SupportsStreaming bool - - // MaxContextLength is the maximum input context length (0 = unknown). - MaxContextLength int - - // Models lists available models for this engine. - Models []string -} - -// Mode represents the preferred engine selection mode. -type Mode string - -const ( - // ModeAuto automatically selects the best available engine. - ModeAuto Mode = "auto" - // ModeMCP forces MCP sampling engine. - ModeMCP Mode = "mcp" - // ModeCLI forces CLI engine. - ModeCLI Mode = "cli" - // ModeAPI forces API engine. - ModeAPI Mode = "api" -) - -// IsValid checks if the engine mode is valid. 
-func (m Mode) IsValid() bool { - switch m { - case ModeAuto, ModeMCP, ModeCLI, ModeAPI: - return true - default: - return false - } -} diff --git a/internal/llm/engine/engine_test.go b/internal/llm/engine/engine_test.go deleted file mode 100644 index 2954b45..0000000 --- a/internal/llm/engine/engine_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package engine - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestComplexity_String(t *testing.T) { - tests := []struct { - name string - complexity Complexity - want string - }{ - {"low", ComplexityLow, "low"}, - {"medium", ComplexityMedium, "medium"}, - {"high", ComplexityHigh, "high"}, - {"unknown", Complexity(99), "unknown"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.want, tt.complexity.String()) - }) - } -} - -func TestRequest_CombinedPrompt(t *testing.T) { - tests := []struct { - name string - req Request - want string - }{ - { - name: "with system and user prompt", - req: Request{ - SystemPrompt: "You are a helpful assistant.", - UserPrompt: "Hello!", - }, - want: "You are a helpful assistant.\n\nHello!", - }, - { - name: "only user prompt", - req: Request{ - SystemPrompt: "", - UserPrompt: "Hello!", - }, - want: "Hello!", - }, - { - name: "empty prompts", - req: Request{ - SystemPrompt: "", - UserPrompt: "", - }, - want: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.want, tt.req.CombinedPrompt()) - }) - } -} - -func TestMode_IsValid(t *testing.T) { - tests := []struct { - name string - mode Mode - valid bool - }{ - {"auto", ModeAuto, true}, - {"mcp", ModeMCP, true}, - {"cli", ModeCLI, true}, - {"api", ModeAPI, true}, - {"invalid", Mode("invalid"), false}, - {"empty", Mode(""), false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.valid, tt.mode.IsValid()) - }) - } -} - -func TestCapabilities_Default(t *testing.T) { - caps := Capabilities{} 
- - assert.False(t, caps.SupportsTemperature) - assert.False(t, caps.SupportsMaxTokens) - assert.False(t, caps.SupportsComplexity) - assert.False(t, caps.SupportsStreaming) - assert.Equal(t, 0, caps.MaxContextLength) - assert.Nil(t, caps.Models) -} - diff --git a/internal/llm/engine/mcp.go b/internal/llm/engine/mcp.go deleted file mode 100644 index 909f231..0000000 --- a/internal/llm/engine/mcp.go +++ /dev/null @@ -1,116 +0,0 @@ -package engine - -import ( - "context" - "fmt" - "os" - - mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// MCPEngine implements LLMEngine interface for MCP sampling. -// It delegates LLM calls to the host application via MCP's CreateMessage. -type MCPEngine struct { - session *mcpsdk.ServerSession - verbose bool -} - -// MCPEngineOption is a functional option for MCPEngine. -type MCPEngineOption func(*MCPEngine) - -// WithMCPVerbose enables verbose logging. -func WithMCPVerbose(verbose bool) MCPEngineOption { - return func(e *MCPEngine) { e.verbose = verbose } -} - -// NewMCPEngine creates a new MCP sampling engine. -func NewMCPEngine(session *mcpsdk.ServerSession, opts ...MCPEngineOption) *MCPEngine { - e := &MCPEngine{ - session: session, - verbose: false, - } - - for _, opt := range opts { - opt(e) - } - - return e -} - -// Name returns the engine identifier. -func (e *MCPEngine) Name() string { - return "mcp-sampling" -} - -// IsAvailable checks if the engine can be used. -func (e *MCPEngine) IsAvailable() bool { - return e.session != nil -} - -// Capabilities returns engine capabilities. -// MCP sampling capabilities depend on the host LLM, so we're conservative here. 
-func (e *MCPEngine) Capabilities() Capabilities { - return Capabilities{ - SupportsTemperature: false, // Host decides - SupportsMaxTokens: true, // Passed to CreateMessage - SupportsComplexity: false, // Host decides model - SupportsStreaming: false, // Not implemented - MaxContextLength: 0, // Unknown - Models: nil, // Host decides - } -} - -// Execute sends the request via MCP sampling. -func (e *MCPEngine) Execute(ctx context.Context, req *Request) (string, error) { - if e.session == nil { - return "", fmt.Errorf("MCP session not available") - } - - if e.verbose { - fmt.Fprintf(os.Stderr, "MCP Sampling request:\n MaxTokens: %d\n Prompt length: %d chars\n", - req.MaxTokens, len(req.UserPrompt)) - } - - maxTokens := req.MaxTokens - if maxTokens == 0 { - maxTokens = defaultAPIMaxTokens - } - - result, err := e.session.CreateMessage(ctx, &mcpsdk.CreateMessageParams{ - Messages: []*mcpsdk.SamplingMessage{ - { - Role: "user", - Content: &mcpsdk.TextContent{Text: req.CombinedPrompt()}, - }, - }, - MaxTokens: int64(maxTokens), - }) - if err != nil { - return "", fmt.Errorf("MCP sampling failed: %w", err) - } - - var response string - if textContent, ok := result.Content.(*mcpsdk.TextContent); ok { - response = textContent.Text - } else { - return "", fmt.Errorf("unexpected content type from MCP sampling") - } - - if e.verbose { - fmt.Fprintf(os.Stderr, "MCP Sampling response:\n Model: %s\n Content length: %d chars\n", - result.Model, len(response)) - } - - return response, nil -} - -// GetSession returns the underlying MCP session. -func (e *MCPEngine) GetSession() *mcpsdk.ServerSession { - return e.session -} - -// SetVerbose sets verbose mode. 
-func (e *MCPEngine) SetVerbose(verbose bool) { - e.verbose = verbose -} - diff --git a/internal/llm/geminicli/provider.go b/internal/llm/geminicli/provider.go new file mode 100644 index 0000000..bb7fe3b --- /dev/null +++ b/internal/llm/geminicli/provider.go @@ -0,0 +1,115 @@ +// Package geminicli provides the Gemini CLI LLM provider. +package geminicli + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "strings" + "time" + + "github.com/DevSymphony/sym-cli/internal/llm" +) + +const ( + providerName = "geminicli" + displayName = "Gemini CLI" + command = "gemini" + defaultModel = "gemini-2.5-flash" + defaultTimeout = 120 * time.Second +) + +func init() { + // Check if CLI is available + path, _ := exec.LookPath(command) + available := path != "" + + llm.RegisterProvider(providerName, newProvider, llm.ProviderInfo{ + Name: providerName, + DisplayName: displayName, + DefaultModel: defaultModel, + Available: available, + Path: path, + Models: []llm.ModelInfo{ + {ID: "gemini-2.5-flash", DisplayName: "2.5 flash", Description: "Fast and efficient", Recommended: true}, + {ID: "gemini-2.5-pro", DisplayName: "2.5 pro", Description: "Higher capability", Recommended: false}, + }, + APIKey: llm.APIKeyConfig{Required: false}, + }) +} + +// Provider implements llm.RawProvider for Gemini CLI. +type Provider struct { + model string + timeout time.Duration + verbose bool + cliPath string +} + +// Compile-time check: Provider must implement RawProvider interface +var _ llm.RawProvider = (*Provider)(nil) + +// newProvider creates a new Gemini CLI provider. +// Returns error if Gemini CLI is not installed. 
+func newProvider(cfg llm.Config) (llm.RawProvider, error) { + path, err := exec.LookPath(command) + if err != nil { + return nil, fmt.Errorf("gemini CLI not installed: run 'npm install -g @google/gemini-cli' to install") + } + + model := cfg.Model + if model == "" { + model = defaultModel + } + + return &Provider{ + model: model, + timeout: defaultTimeout, + verbose: cfg.Verbose, + cliPath: path, + }, nil +} + +func (p *Provider) Name() string { + return providerName +} + +func (p *Provider) ExecuteRaw(ctx context.Context, prompt string, format llm.ResponseFormat) (string, error) { + args := []string{"prompt", "-m", p.model, prompt} + + if p.verbose { + fmt.Fprintf(os.Stderr, "[geminicli] Model: %s, Prompt: %d chars\n", p.model, len(prompt)) + } + + cmdCtx, cancel := context.WithTimeout(ctx, p.timeout) + defer cancel() + + cmd := exec.CommandContext(cmdCtx, p.cliPath, args...) + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + if err != nil { + if cmdCtx.Err() == context.DeadlineExceeded { + return "", fmt.Errorf("gemini CLI timed out after %v", p.timeout) + } + return "", fmt.Errorf("gemini CLI failed: %w\nstderr: %s", err, stderr.String()) + } + + response := strings.TrimSpace(stdout.String()) + + if p.verbose { + fmt.Fprintf(os.Stderr, "[geminicli] Response: %d chars\n", len(response)) + } + + return response, nil +} + +// Close is a no-op for CLI-based providers. +func (p *Provider) Close() error { + return nil +} diff --git a/internal/llm/llm.go b/internal/llm/llm.go new file mode 100644 index 0000000..3d8fb55 --- /dev/null +++ b/internal/llm/llm.go @@ -0,0 +1,73 @@ +// Package llm provides a unified interface for LLM providers. +package llm + +import "context" + +// Provider is the interface for LLM providers. +type Provider interface { + // Execute sends a prompt and returns the parsed response. 
+ Execute(ctx context.Context, prompt string, format ResponseFormat) (string, error) + // Name returns the provider name. + Name() string + // Close releases any resources held by the provider. + Close() error +} + +// RawProvider is the interface for provider implementations. +// Provider implementations should implement this interface. +// The registry will automatically wrap RawProvider with parsing logic. +type RawProvider interface { + // ExecuteRaw sends a prompt and returns the raw (unparsed) response. + ExecuteRaw(ctx context.Context, prompt string, format ResponseFormat) (string, error) + // Name returns the provider name. + Name() string + // Close releases any resources held by the provider. + Close() error +} + +// ResponseFormat specifies the expected response format. +type ResponseFormat string + +const ( + Text ResponseFormat = "text" + JSON ResponseFormat = "json" + XML ResponseFormat = "xml" +) + +// String returns the string representation of the format. +func (f ResponseFormat) String() string { + return string(f) +} + +// Config holds LLM provider configuration. +type Config struct { + Provider string // "claudecode", "geminicli", "openaiapi" + Model string // Model name (optional, uses provider default) + Verbose bool // Enable verbose logging +} + +// ModelInfo describes a model available for a provider. +type ModelInfo struct { + ID string // Internal model identifier (e.g., "sonnet", "gpt-4o-mini") + DisplayName string // Human-readable name for UI + Description string // Short description + Recommended bool // Default/recommended model flag +} + +// APIKeyConfig describes API key requirements for a provider. +type APIKeyConfig struct { + Required bool // Whether this provider requires an API key + EnvVarName string // Environment variable name (e.g., "OPENAI_API_KEY") + Prefix string // Expected prefix for validation (e.g., "sk-") +} + +// ProviderInfo contains provider metadata. 
+type ProviderInfo struct { + Name string + DisplayName string + DefaultModel string + Available bool + Path string // CLI path or empty for API providers + Models []ModelInfo // Available models for this provider + APIKey APIKeyConfig // API key configuration +} diff --git a/internal/llm/llm_test.go b/internal/llm/llm_test.go new file mode 100644 index 0000000..3598231 --- /dev/null +++ b/internal/llm/llm_test.go @@ -0,0 +1,166 @@ +package llm + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockRawProvider is a test provider for unit tests. +type mockRawProvider struct { + name string + response string + err error +} + +func (m *mockRawProvider) ExecuteRaw(ctx context.Context, prompt string, format ResponseFormat) (string, error) { + if m.err != nil { + return "", m.err + } + return m.response, nil +} + +func (m *mockRawProvider) Name() string { + return m.name +} + +func (m *mockRawProvider) Close() error { + return nil +} + +func init() { + // Register a test provider + RegisterProvider("test-provider", func(cfg Config) (RawProvider, error) { + return &mockRawProvider{ + name: "test-provider", + response: "test response", + }, nil + }, ProviderInfo{ + Name: "test-provider", + DisplayName: "Test Provider", + DefaultModel: "test-model", + Available: true, + }) +} + +func TestNew(t *testing.T) { + t.Run("creates provider from config", func(t *testing.T) { + provider, err := New(Config{Provider: "test-provider"}) + require.NoError(t, err) + assert.NotNil(t, provider) + assert.Equal(t, "test-provider", provider.Name()) + }) + + t.Run("returns error for unknown provider", func(t *testing.T) { + _, err := New(Config{Provider: "unknown-provider"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown provider") + }) +} + +func TestProvider_Execute(t *testing.T) { + provider, err := New(Config{Provider: "test-provider"}) + require.NoError(t, err) + + t.Run("executes with text 
format", func(t *testing.T) { + result, err := provider.Execute(context.Background(), "test prompt", Text) + require.NoError(t, err) + assert.Equal(t, "test response", result) + }) +} + +func TestConfigValidate(t *testing.T) { + t.Run("returns error when provider is empty", func(t *testing.T) { + cfg := Config{} + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "provider is required") + }) + + t.Run("valid config with CLI provider", func(t *testing.T) { + cfg := Config{Provider: "claudecode"} + err := cfg.Validate() + assert.NoError(t, err) + }) + + t.Run("valid config with openaiapi", func(t *testing.T) { + // Note: API key validation is now handled by the provider itself + cfg := Config{Provider: "openaiapi"} + err := cfg.Validate() + assert.NoError(t, err) + }) +} + +func TestLoadConfig(t *testing.T) { + t.Run("returns config from LoadConfigFromDir", func(t *testing.T) { + // LoadConfig is a simple wrapper around LoadConfigFromDir + cfg := LoadConfig() + // Without .sym/config.json, it should return empty config + // We just verify it doesn't panic + _ = cfg.Provider + _ = cfg.Model + }) +} + +func TestLoadConfigFromDir(t *testing.T) { + t.Run("returns empty config when no config found", func(t *testing.T) { + tmpDir := t.TempDir() + cfg := LoadConfigFromDir(tmpDir) + assert.Empty(t, cfg.Provider) + assert.Empty(t, cfg.Model) + }) +} + +func TestGetProviderInfo(t *testing.T) { + t.Run("returns info for registered provider", func(t *testing.T) { + info := GetProviderInfo("test-provider") + require.NotNil(t, info) + assert.Equal(t, "test-provider", info.Name) + assert.Equal(t, "Test Provider", info.DisplayName) + }) + + t.Run("returns nil for unknown provider", func(t *testing.T) { + info := GetProviderInfo("unknown") + assert.Nil(t, info) + }) +} + +func TestListProviders(t *testing.T) { + providers := ListProviders() + assert.NotEmpty(t, providers) + + // Should include our test provider + var found bool + for _, p := range 
providers { + if p.Name == "test-provider" { + found = true + break + } + } + assert.True(t, found, "test-provider should be in list") +} + +func Test_parse(t *testing.T) { + t.Run("parses JSON from response", func(t *testing.T) { + response := `Here is the result: {"key": "value"}` + result, err := parse(response, JSON) + require.NoError(t, err) + assert.Equal(t, `{"key": "value"}`, result) + }) + + t.Run("parses XML from response", func(t *testing.T) { + response := `Here is XML: value` + result, err := parse(response, XML) + require.NoError(t, err) + assert.Equal(t, `value`, result) + }) + + t.Run("returns text as-is", func(t *testing.T) { + response := "Just plain text" + result, err := parse(response, Text) + require.NoError(t, err) + assert.Equal(t, response, result) + }) +} diff --git a/internal/llm/openaiapi/provider.go b/internal/llm/openaiapi/provider.go new file mode 100644 index 0000000..788cf1c --- /dev/null +++ b/internal/llm/openaiapi/provider.go @@ -0,0 +1,221 @@ +// Package openaiapi provides the OpenAI API LLM provider. +package openaiapi + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/DevSymphony/sym-cli/internal/envutil" + "github.com/DevSymphony/sym-cli/internal/llm" +) + +const ( + providerName = "openaiapi" + displayName = "OpenAI API" + apiURL = "https://api.openai.com/v1/chat/completions" + defaultModel = "gpt-4o-mini" + defaultTimeout = 60 * time.Second + defaultMaxTokens = 1000 + defaultTemperature = 1.0 +) + +// ErrAPIKeyRequired is returned when OpenAI API key is not provided. 
+var ErrAPIKeyRequired = errors.New("openaiapi: API key is required (set OPENAI_API_KEY environment variable)") + +func init() { + // OpenAI availability depends on API key (from env vars or .sym/.env) + llm.RegisterProvider(providerName, newProvider, llm.ProviderInfo{ + Name: providerName, + DisplayName: displayName, + DefaultModel: defaultModel, + Available: envutil.GetAPIKey("OPENAI_API_KEY") != "", + Path: "", + Models: []llm.ModelInfo{ + {ID: "gpt-4o-mini", DisplayName: "gpt-4o-mini", Description: "Fast and efficient", Recommended: true}, + {ID: "gpt-5-mini", DisplayName: "gpt-5-mini", Description: "Next generation model", Recommended: false}, + }, + APIKey: llm.APIKeyConfig{ + Required: true, + EnvVarName: "OPENAI_API_KEY", + Prefix: "sk-", + }, + }) +} + +// Provider implements llm.RawProvider for OpenAI API. +type Provider struct { + apiKey string + model string + httpClient *http.Client + maxTokens int + temperature float64 + verbose bool +} + +// Compile-time check: Provider must implement RawProvider interface +var _ llm.RawProvider = (*Provider)(nil) + +// newProvider creates a new OpenAI API provider. +// Returns ErrAPIKeyRequired if API key is not provided. 
+func newProvider(cfg llm.Config) (llm.RawProvider, error) { + // Provider handles its own API key loading from env vars and .sym/.env + apiKey := envutil.GetAPIKey("OPENAI_API_KEY") + if apiKey == "" { + return nil, ErrAPIKeyRequired + } + + model := cfg.Model + if model == "" { + model = defaultModel + } + + return &Provider{ + apiKey: apiKey, + model: model, + httpClient: &http.Client{Timeout: defaultTimeout}, + maxTokens: defaultMaxTokens, + temperature: defaultTemperature, + verbose: cfg.Verbose, + }, nil +} + +func (p *Provider) Name() string { + return providerName +} + +func (p *Provider) ExecuteRaw(ctx context.Context, prompt string, format llm.ResponseFormat) (string, error) { + apiReq := apiRequest{ + Model: p.model, + Messages: []apiMessage{ + {Role: "user", Content: prompt}, + }, + } + + // Model-based parameter switching: + // - Reasoning models (gpt-5, o1, o3, o4): use max_completion_tokens, reasoning_effort + // - Standard models (gpt-4o): use max_tokens, temperature + if p.isReasoningModel() { + apiReq.MaxCompletionTokens = p.maxTokens + apiReq.ReasoningEffort = "medium" + } else { + apiReq.MaxTokens = p.maxTokens + apiReq.Temperature = p.temperature + } + + jsonData, err := json.Marshal(apiReq) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) + + if p.verbose { + fmt.Fprintf(os.Stderr, "[openaiapi] Model: %s, Prompt: %d chars\n", p.model, len(prompt)) + } + + resp, err := p.httpClient.Do(httpReq) + if err != nil { + return "", fmt.Errorf("failed to send request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read 
response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(body)) + } + + var apiResp apiResponse + if err := json.Unmarshal(body, &apiResp); err != nil { + return "", fmt.Errorf("failed to unmarshal response: %w", err) + } + + if apiResp.Error != nil { + return "", fmt.Errorf("OpenAI API error: %s (type: %s, code: %s)", + apiResp.Error.Message, apiResp.Error.Type, apiResp.Error.Code) + } + + if len(apiResp.Choices) == 0 { + return "", fmt.Errorf("no choices in response") + } + + content := apiResp.Choices[0].Message.Content + + if p.verbose { + fmt.Fprintf(os.Stderr, "[openaiapi] Response: %d chars, Tokens: %d\n", len(content), apiResp.Usage.TotalTokens) + } + + return content, nil +} + +// isReasoningModel returns true if the model is a reasoning model (gpt-5, o1, o3, o4). +func (p *Provider) isReasoningModel() bool { + return strings.HasPrefix(p.model, "gpt-5") || + strings.HasPrefix(p.model, "o1") || + strings.HasPrefix(p.model, "o3") || + strings.HasPrefix(p.model, "o4") +} + +type apiRequest struct { + Model string `json:"model"` + Messages []apiMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` +} + +type apiMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type apiResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Index int `json:"index"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens 
int `json:"total_tokens"` + } `json:"usage"` + Error *struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` + } `json:"error,omitempty"` +} + +// Close releases HTTP client resources. +func (p *Provider) Close() error { + if p.httpClient != nil { + p.httpClient.CloseIdleConnections() + } + return nil +} diff --git a/internal/llm/parser.go b/internal/llm/parser.go new file mode 100644 index 0000000..a9e53f9 --- /dev/null +++ b/internal/llm/parser.go @@ -0,0 +1,250 @@ +package llm + +import ( + "encoding/json" + "encoding/xml" + "errors" + "regexp" + "strings" +) + +// internalFormat is the internal response format type used by the parser. +type internalFormat int + +const ( + ResponseFormatText internalFormat = iota + ResponseFormatJSON + ResponseFormatXML +) + +func (f internalFormat) String() string { + switch f { + case ResponseFormatJSON: + return "json" + case ResponseFormatXML: + return "xml" + default: + return "text" + } +} + +// ParseOptions configures response parsing behavior. +type ParseOptions struct { + Format internalFormat + StrictMode bool // If true, return error when format not found +} + +var ( + ErrNoJSONFound = errors.New("no valid JSON found in response") + ErrNoXMLFound = errors.New("no valid XML found in response") +) + +// ParseResponse extracts structured content from LLM responses. +// Handles cases where LLM adds preamble text like "I need to analyze..." +func ParseResponse(response string, opts ParseOptions) (string, error) { + switch opts.Format { + case ResponseFormatJSON: + return extractJSON(response, opts.StrictMode) + case ResponseFormatXML: + return extractXML(response, opts.StrictMode) + default: + return response, nil + } +} + +// extractJSON finds and extracts JSON content from response. 
+func extractJSON(response string, strict bool) (string, error) { + // Strategy 1: Look for code block with json marker + if jsonBlock := extractCodeBlock(response, "json"); jsonBlock != "" { + if isValidJSON(jsonBlock) { + return jsonBlock, nil + } + } + + // Strategy 2: Find outermost { } or [ ] + if jsonStr := findJSONBoundaries(response); jsonStr != "" { + if isValidJSON(jsonStr) { + return jsonStr, nil + } + } + + // Strategy 3: Try entire response as JSON + trimmed := strings.TrimSpace(response) + if isValidJSON(trimmed) { + return trimmed, nil + } + + if strict { + return "", ErrNoJSONFound + } + return response, nil +} + +// extractXML finds and extracts XML content from response. +func extractXML(response string, strict bool) (string, error) { + // Strategy 1: Look for code block with xml marker + if xmlBlock := extractCodeBlock(response, "xml"); xmlBlock != "" { + if isValidXML(xmlBlock) { + return xmlBlock, nil + } + } + + // Strategy 2: Find tag + if xmlStr := findXMLBoundaries(response); xmlStr != "" { + if isValidXML(xmlStr) { + return xmlStr, nil + } + } + + // Strategy 3: Try entire response as XML + trimmed := strings.TrimSpace(response) + if isValidXML(trimmed) { + return trimmed, nil + } + + if strict { + return "", ErrNoXMLFound + } + return response, nil +} + +// extractCodeBlock matches ```lang ... ``` code blocks. +func extractCodeBlock(response, lang string) string { + patternStr := "(?s)```" + regexp.QuoteMeta(lang) + "\\s*\\n?(.*?)```" + pattern := regexp.MustCompile(patternStr) + matches := pattern.FindStringSubmatch(response) + if len(matches) > 1 { + return strings.TrimSpace(matches[1]) + } + return "" +} + +// findJSONBoundaries finds first { or [ and matches to corresponding } or ]. 
+func findJSONBoundaries(s string) string { + start := -1 + var startChar byte + + for i := 0; i < len(s); i++ { + if s[i] == '{' || s[i] == '[' { + start = i + startChar = s[i] + break + } + } + + if start == -1 { + return "" + } + + var endChar byte + if startChar == '{' { + endChar = '}' + } else { + endChar = ']' + } + + // Find matching end (handle nesting) + depth := 0 + inString := false + escaped := false + + for i := start; i < len(s); i++ { + c := s[i] + + if escaped { + escaped = false + continue + } + + if c == '\\' && inString { + escaped = true + continue + } + + if c == '"' { + inString = !inString + continue + } + + if inString { + continue + } + + if c == startChar { + depth++ + } else if c == endChar { + depth-- + if depth == 0 { + return s[start : i+1] + } + } + } + + return "" +} + +// findXMLBoundaries finds . +func findXMLBoundaries(s string) string { + // Look for ") + if end <= start { + return "" + } + + return s[start : end+1] +} + +func isValidJSON(s string) bool { + var js interface{} + return json.Unmarshal([]byte(s), &js) == nil +} + +func isValidXML(s string) bool { + // Check if string starts with < (basic XML requirement) + trimmed := strings.TrimSpace(s) + if len(trimmed) == 0 || trimmed[0] != '<' { + return false + } + + // Simple XML validation using decoder + decoder := xml.NewDecoder(strings.NewReader(s)) + tokenCount := 0 + for { + _, err := decoder.Token() + if err != nil { + // Valid XML must have at least one token + return err.Error() == "EOF" && tokenCount > 0 + } + tokenCount++ + } +} + +// parse extracts structured content from LLM responses based on format. +func parse(response string, format ResponseFormat) (string, error) { + opts := ParseOptions{ + Format: toInternalFormat(format), + StrictMode: true, + } + return ParseResponse(response, opts) +} + +// toInternalFormat converts ResponseFormat to internal format for parser. 
+func toInternalFormat(f ResponseFormat) internalFormat { + switch f { + case JSON: + return ResponseFormatJSON + case XML: + return ResponseFormatXML + default: + return ResponseFormatText + } +} diff --git a/internal/llm/parser_test.go b/internal/llm/parser_test.go new file mode 100644 index 0000000..d25bfdf --- /dev/null +++ b/internal/llm/parser_test.go @@ -0,0 +1,314 @@ +package llm + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInternalFormatString(t *testing.T) { + tests := []struct { + format internalFormat + expected string + }{ + {ResponseFormatText, "text"}, + {ResponseFormatJSON, "json"}, + {ResponseFormatXML, "xml"}, + {internalFormat(99), "text"}, // unknown defaults to text + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.format.String()) + }) + } +} + +func TestResponseFormatString(t *testing.T) { + assert.Equal(t, "text", Text.String()) + assert.Equal(t, "json", JSON.String()) + assert.Equal(t, "xml", XML.String()) +} + +func TestParseResponseText(t *testing.T) { + t.Run("returns response as-is for text format", func(t *testing.T) { + response := "This is plain text response." 
+ result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatText}) + + require.NoError(t, err) + assert.Equal(t, response, result) + }) +} + +func TestParseResponseJSON(t *testing.T) { + t.Run("extracts JSON from code block", func(t *testing.T) { + response := "Here is the result:\n```json\n{\"key\": \"value\"}\n```" + result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatJSON}) + + require.NoError(t, err) + assert.Equal(t, `{"key": "value"}`, result) + }) + + t.Run("extracts JSON object from mixed content", func(t *testing.T) { + response := `I analyzed the code and found: {"result": true, "count": 5} Hope this helps!` + result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatJSON}) + + require.NoError(t, err) + assert.Equal(t, `{"result": true, "count": 5}`, result) + }) + + t.Run("extracts JSON array from mixed content", func(t *testing.T) { + response := `The linters are: ["eslint", "prettier"]` + result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatJSON}) + + require.NoError(t, err) + assert.Equal(t, `["eslint", "prettier"]`, result) + }) + + t.Run("handles nested JSON objects", func(t *testing.T) { + response := `{"outer": {"inner": {"deep": "value"}}}` + result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatJSON}) + + require.NoError(t, err) + assert.Equal(t, response, result) + }) + + t.Run("handles JSON with escaped characters", func(t *testing.T) { + response := `{"message": "Hello \"World\"", "path": "C:\\Users"}` + result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatJSON}) + + require.NoError(t, err) + assert.Equal(t, response, result) + }) + + t.Run("returns original on no JSON in non-strict mode", func(t *testing.T) { + response := "No JSON here, just text." 
+ result, err := ParseResponse(response, ParseOptions{ + Format: ResponseFormatJSON, + StrictMode: false, + }) + + require.NoError(t, err) + assert.Equal(t, response, result) + }) + + t.Run("returns error on no JSON in strict mode", func(t *testing.T) { + response := "No JSON here, just text." + _, err := ParseResponse(response, ParseOptions{ + Format: ResponseFormatJSON, + StrictMode: true, + }) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrNoJSONFound) + }) + + t.Run("handles empty JSON object", func(t *testing.T) { + response := `{}` + result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatJSON}) + + require.NoError(t, err) + assert.Equal(t, `{}`, result) + }) + + t.Run("handles empty JSON array", func(t *testing.T) { + response := `[]` + result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatJSON}) + + require.NoError(t, err) + assert.Equal(t, `[]`, result) + }) +} + +func TestParseResponseXML(t *testing.T) { + t.Run("extracts XML from code block", func(t *testing.T) { + response := "Here is the XML:\n```xml\nvalue\n```" + result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatXML}) + + require.NoError(t, err) + assert.Equal(t, `value`, result) + }) + + t.Run("extracts XML with declaration", func(t *testing.T) { + response := `Some preamble. test Some epilogue.` + result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatXML}) + + require.NoError(t, err) + assert.Contains(t, result, ``) + assert.Contains(t, result, ``) + }) + + t.Run("extracts XML without declaration", func(t *testing.T) { + response := `Analysis complete: value` + result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatXML}) + + require.NoError(t, err) + assert.Equal(t, `value`, result) + }) + + t.Run("returns original on no XML in non-strict mode", func(t *testing.T) { + response := "No XML here." 
+ result, err := ParseResponse(response, ParseOptions{ + Format: ResponseFormatXML, + StrictMode: false, + }) + + require.NoError(t, err) + assert.Equal(t, response, result) + }) + + t.Run("returns error on no XML in strict mode", func(t *testing.T) { + response := "No XML here." + _, err := ParseResponse(response, ParseOptions{ + Format: ResponseFormatXML, + StrictMode: true, + }) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrNoXMLFound) + }) + + t.Run("handles self-closing XML tags", func(t *testing.T) { + response := `` + result, err := ParseResponse(response, ParseOptions{Format: ResponseFormatXML}) + + require.NoError(t, err) + assert.Equal(t, response, result) + }) +} + +func TestExtractCodeBlock(t *testing.T) { + t.Run("extracts json code block", func(t *testing.T) { + response := "```json\n{\"key\": \"value\"}\n```" + result := extractCodeBlock(response, "json") + assert.Equal(t, `{"key": "value"}`, result) + }) + + t.Run("extracts xml code block", func(t *testing.T) { + response := "```xml\nvalue\n```" + result := extractCodeBlock(response, "xml") + assert.Equal(t, `value`, result) + }) + + t.Run("returns empty for no match", func(t *testing.T) { + response := "```python\nprint('hello')\n```" + result := extractCodeBlock(response, "json") + assert.Empty(t, result) + }) + + t.Run("handles code block without newlines", func(t *testing.T) { + response := "```json{\"key\": \"value\"}```" + result := extractCodeBlock(response, "json") + assert.Equal(t, `{"key": "value"}`, result) + }) +} + +func TestFindJSONBoundaries(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple object", + input: `{"key": "value"}`, + expected: `{"key": "value"}`, + }, + { + name: "simple array", + input: `["a", "b"]`, + expected: `["a", "b"]`, + }, + { + name: "nested objects", + input: `{"outer": {"inner": "value"}}`, + expected: `{"outer": {"inner": "value"}}`, + }, + { + name: "with preamble", + input: `Here is JSON: 
{"key": "value"}`, + expected: `{"key": "value"}`, + }, + { + name: "with epilogue", + input: `{"key": "value"} Hope this helps!`, + expected: `{"key": "value"}`, + }, + { + name: "with braces in string", + input: `{"msg": "use { and } carefully"}`, + expected: `{"msg": "use { and } carefully"}`, + }, + { + name: "no JSON", + input: `Just plain text`, + expected: ``, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := findJSONBoundaries(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestFindXMLBoundaries(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple element", + input: `value`, + expected: `value`, + }, + { + name: "with declaration", + input: ``, + expected: ``, + }, + { + name: "with preamble", + input: `Here is XML: test`, + expected: `test`, + }, + { + name: "no XML", + input: `No XML here`, + expected: ``, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := findXMLBoundaries(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsValidJSON(t *testing.T) { + assert.True(t, isValidJSON(`{"key": "value"}`)) + assert.True(t, isValidJSON(`["a", "b"]`)) + assert.True(t, isValidJSON(`null`)) + assert.True(t, isValidJSON(`123`)) + assert.True(t, isValidJSON(`"string"`)) + + assert.False(t, isValidJSON(`{invalid}`)) + assert.False(t, isValidJSON(`not json`)) + assert.False(t, isValidJSON(``)) +} + +func TestIsValidXML(t *testing.T) { + assert.True(t, isValidXML(`value`)) + assert.True(t, isValidXML(``)) + assert.True(t, isValidXML(``)) + + assert.False(t, isValidXML(`not xml`)) + assert.False(t, isValidXML(``)) + assert.False(t, isValidXML(``)) +} diff --git a/internal/llm/registry.go b/internal/llm/registry.go new file mode 100644 index 0000000..9246680 --- /dev/null +++ b/internal/llm/registry.go @@ -0,0 +1,202 @@ +// Package llm provides a unified interface for LLM providers. 
+package llm + +import ( + "fmt" + "sort" + "strings" +) + +// rawProviderFactory creates a RawProvider instance. +type rawProviderFactory func(cfg Config) (RawProvider, error) + +var providers = make(map[string]rawProviderFactory) +var providerMeta = make(map[string]ProviderInfo) + +// RegisterProvider registers a provider factory. +// Called by provider packages in their init() functions. +func RegisterProvider(name string, factory rawProviderFactory, info ProviderInfo) { + providers[name] = factory + providerMeta[name] = info +} + +// New creates a new LLM provider based on the configuration. +// Returns an error if the provider is not available (CLI not installed, API key missing, etc.) +// The returned Provider automatically handles response parsing. +func New(cfg Config) (Provider, error) { + factory, ok := providers[cfg.Provider] + if !ok { + return nil, fmt.Errorf("unknown provider: %s (available: %s)", cfg.Provider, availableProviders()) + } + rawProvider, err := factory(cfg) + if err != nil { + return nil, err + } + return wrapWithParser(rawProvider), nil +} + +// GetProviderInfo returns metadata for a provider. +func GetProviderInfo(name string) *ProviderInfo { + info, ok := providerMeta[name] + if !ok { + return nil + } + return &info +} + +// ListProviders returns info for all registered providers. +func ListProviders() []ProviderInfo { + result := make([]ProviderInfo, 0, len(providerMeta)) + for _, info := range providerMeta { + result = append(result, info) + } + return result +} + +func availableProviders() string { + names := make([]string, 0, len(providers)) + for name := range providers { + names = append(names, name) + } + return strings.Join(names, ", ") +} + +// GetProviderOptions returns a list of display names for all registered providers. +// Results are sorted alphabetically. If includeSkip is true, "Skip" is appended. 
+func GetProviderOptions(includeSkip bool) []string { + result := make([]string, 0, len(providerMeta)+1) + for _, info := range providerMeta { + result = append(result, info.DisplayName) + } + sort.Strings(result) + if includeSkip { + result = append(result, "Skip") + } + return result +} + +// GetProviderByDisplayName returns provider info by display name. +func GetProviderByDisplayName(displayName string) *ProviderInfo { + for _, info := range providerMeta { + if info.DisplayName == displayName { + infoCopy := info + return &infoCopy + } + } + return nil +} + +// GetModelOptions returns model display options for a provider. +// Format: "DisplayName - Description (recommended)" for recommended models. +func GetModelOptions(providerName string) []string { + info := GetProviderInfo(providerName) + if info == nil || len(info.Models) == 0 { + return nil + } + + result := make([]string, 0, len(info.Models)) + for _, model := range info.Models { + option := model.DisplayName + if model.Description != "" { + option += " - " + model.Description + } + if model.Recommended { + option += " (recommended)" + } + result = append(result, option) + } + return result +} + +// GetModelIDFromOption extracts the model ID from a display option. +func GetModelIDFromOption(providerName, option string) string { + info := GetProviderInfo(providerName) + if info == nil { + return "" + } + + for _, model := range info.Models { + displayOption := model.DisplayName + if model.Description != "" { + displayOption += " - " + model.Description + } + if model.Recommended { + displayOption += " (recommended)" + } + if displayOption == option { + return model.ID + } + } + return "" +} + +// GetDefaultModelOption returns the recommended model display option for a provider. 
+func GetDefaultModelOption(providerName string) string { + info := GetProviderInfo(providerName) + if info == nil { + return "" + } + + for _, model := range info.Models { + if model.Recommended { + option := model.DisplayName + if model.Description != "" { + option += " - " + model.Description + } + option += " (recommended)" + return option + } + } + + // Fall back to first model if no recommended + if len(info.Models) > 0 { + model := info.Models[0] + option := model.DisplayName + if model.Description != "" { + option += " - " + model.Description + } + return option + } + return "" +} + +// RequiresAPIKey returns true if the provider requires an API key. +func RequiresAPIKey(providerName string) bool { + info := GetProviderInfo(providerName) + if info == nil { + return false + } + return info.APIKey.Required +} + +// ValidateAPIKey validates an API key for a provider. +// Returns nil if valid, error with message if invalid. +func ValidateAPIKey(providerName, apiKey string) error { + info := GetProviderInfo(providerName) + if info == nil { + return fmt.Errorf("unknown provider: %s", providerName) + } + + if !info.APIKey.Required { + return nil // No validation needed + } + + if apiKey == "" { + return fmt.Errorf("API key cannot be empty") + } + + if info.APIKey.Prefix != "" && !strings.HasPrefix(apiKey, info.APIKey.Prefix) { + return fmt.Errorf("API key should start with '%s'", info.APIKey.Prefix) + } + + return nil +} + +// GetAPIKeyEnvVar returns the environment variable name for the provider's API key. +func GetAPIKeyEnvVar(providerName string) string { + info := GetProviderInfo(providerName) + if info == nil { + return "" + } + return info.APIKey.EnvVarName +} diff --git a/internal/llm/wrapper.go b/internal/llm/wrapper.go new file mode 100644 index 0000000..7314d50 --- /dev/null +++ b/internal/llm/wrapper.go @@ -0,0 +1,32 @@ +package llm + +import "context" + +// parsedProvider wraps a RawProvider with automatic response parsing. 
+type parsedProvider struct { + raw RawProvider +} + +// wrapWithParser creates a Provider that automatically parses responses. +func wrapWithParser(raw RawProvider) Provider { + return &parsedProvider{raw: raw} +} + +// Execute sends a prompt and returns the parsed response. +func (p *parsedProvider) Execute(ctx context.Context, prompt string, format ResponseFormat) (string, error) { + response, err := p.raw.ExecuteRaw(ctx, prompt, format) + if err != nil { + return "", err + } + return parse(response, format) +} + +// Name returns the provider name. +func (p *parsedProvider) Name() string { + return p.raw.Name() +} + +// Close releases any resources held by the provider. +func (p *parsedProvider) Close() error { + return p.raw.Close() +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 1a41c37..655c5e1 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -9,8 +9,8 @@ import ( "strings" "time" + "github.com/DevSymphony/sym-cli/internal/config" "github.com/DevSymphony/sym-cli/internal/converter" - "github.com/DevSymphony/sym-cli/internal/envutil" "github.com/DevSymphony/sym-cli/internal/git" "github.com/DevSymphony/sym-cli/internal/llm" "github.com/DevSymphony/sym-cli/internal/policy" @@ -34,17 +34,20 @@ func ConvertPolicyWithLLM(userPolicyPath, codePolicyPath string) error { return fmt.Errorf("failed to parse user policy: %w", err) } - // Setup LLM client (backend auto-selection via @llm) - llmClient := llm.NewClient( - llm.WithTimeout(30 * time.Second), - ) + // Setup LLM provider + cfg := llm.LoadConfig() + llmProvider, err := llm.New(cfg) + if err != nil { + return fmt.Errorf("failed to create LLM provider: %w", err) + } + defer llmProvider.Close() // Create converter with output directory outputDir := filepath.Dir(codePolicyPath) - conv := converter.NewConverter(llmClient, outputDir) + conv := converter.NewConverter(llmProvider, outputDir) - // Setup context with timeout - ctx, cancel := 
context.WithTimeout(context.Background(), time.Duration(30*10)*time.Second) + // Setup context with timeout (10 minutes to match validator) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() fmt.Fprintf(os.Stderr, "Converting %d rules...\n", len(userPolicy.Rules)) @@ -122,8 +125,9 @@ func (s *Server) Start() error { // Only try to load policies if we have a directory if dir != "" { // Try to load user-policy.json for natural language descriptions - // First check .env for POLICY_PATH, otherwise use default - userPolicyPath := envutil.GetPolicyPath() + // Get policy path from config.json + projectCfg, _ := config.LoadProjectConfig() + userPolicyPath := projectCfg.PolicyPath if userPolicyPath == "" { userPolicyPath = filepath.Join(dir, "user-policy.json") } else if !filepath.IsAbs(userPolicyPath) { @@ -471,20 +475,20 @@ func (s *Server) handleValidateCode(ctx context.Context, session *sdkmcp.ServerS }, nil } - var llmClient *llm.Client - if session != nil { - // MCP mode: use host LLM via sampling - llmClient = llm.NewClient(llm.WithMCPSession(session)) - fmt.Fprintf(os.Stderr, "✓ Using host LLM via MCP sampling\n") - } else { - // Auto mode: use configured LLM backend (CLI/API) - llmClient = llm.NewClient() - fmt.Fprintf(os.Stderr, "✓ Using configured LLM backend\n") + // Use configured LLM provider + llmCfg := llm.LoadConfig() + llmProvider, err := llm.New(llmCfg) + if err != nil { + return nil, &RPCError{ + Code: -32000, + Message: fmt.Sprintf("failed to create LLM provider: %v", err), + } } + fmt.Fprintf(os.Stderr, "✓ Using LLM provider: %s\n", llmProvider.Name()) // Create unified validator that handles all engines + RBAC v := validator.NewValidator(validationPolicy, false) // verbose=false for MCP - v.SetLLMClient(llmClient) + v.SetLLMProvider(llmProvider) defer func() { _ = v.Close() // Ignore close error in MCP context }() @@ -595,8 +599,7 @@ func (s *Server) getValidationPolicy() (*schema.CodePolicy, error) { // 
needsConversion checks if user policy needs to be converted to code policy. // Returns true if: // 1. code-policy.json doesn't exist, OR -// 2. user policy has more rules than code policy (indicating new rules added), OR -// 3. user policy has rule IDs that don't exist in code policy +// 2. user policy has rule IDs that don't exist in code policy (after extracting source ID) func (s *Server) needsConversion(codePolicyPath string) bool { // If no code policy exists, conversion is needed if s.codePolicy == nil { @@ -608,19 +611,17 @@ func (s *Server) needsConversion(codePolicyPath string) bool { return false } - // Check if user policy has more rules - if len(s.userPolicy.Rules) > len(s.codePolicy.Rules) { - return true - } - - // Check if all user policy rule IDs exist in code policy - codePolicyRuleIDs := make(map[string]bool) + // Extract source rule IDs from code policy + // code-policy rules have IDs like "FMT-001-eslint", we extract "FMT-001" + codePolicySourceIDs := make(map[string]bool) for _, rule := range s.codePolicy.Rules { - codePolicyRuleIDs[rule.ID] = true + sourceID := extractSourceRuleID(rule.ID) + codePolicySourceIDs[sourceID] = true } + // Check if all user policy rule IDs have corresponding code policy rules for _, userRule := range s.userPolicy.Rules { - if !codePolicyRuleIDs[userRule.ID] { + if !codePolicySourceIDs[userRule.ID] { // Found a user rule that doesn't exist in code policy return true } @@ -629,6 +630,19 @@ func (s *Server) needsConversion(codePolicyPath string) bool { return false } +// extractSourceRuleID extracts the original user-policy rule ID from a code-policy rule ID. 
+// For example: "FMT-001-eslint" -> "FMT-001" +func extractSourceRuleID(codePolicyRuleID string) string { + // Known linter suffixes that are appended during conversion (see converter.go:179) + linterSuffixes := []string{"-eslint", "-prettier", "-tsc", "-pylint", "-checkstyle", "-pmd", "-llm-validator"} + for _, suffix := range linterSuffixes { + if strings.HasSuffix(codePolicyRuleID, suffix) { + return strings.TrimSuffix(codePolicyRuleID, suffix) + } + } + return codePolicyRuleID +} + // convertUserPolicy converts user policy to code policy using LLM. // This is a wrapper that calls the shared conversion logic. func (s *Server) convertUserPolicy(userPolicyPath, codePolicyPath string) error { diff --git a/internal/server/server.go b/internal/server/server.go index 1f2d03a..e77d87c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -13,8 +13,8 @@ import ( "strings" "time" + "github.com/DevSymphony/sym-cli/internal/config" "github.com/DevSymphony/sym-cli/internal/converter" - "github.com/DevSymphony/sym-cli/internal/envutil" "github.com/DevSymphony/sym-cli/internal/llm" "github.com/DevSymphony/sym-cli/internal/policy" "github.com/DevSymphony/sym-cli/internal/roles" @@ -98,10 +98,11 @@ func (s *Server) corsMiddleware(next http.Handler) http.Handler { // hasPermissionForRole checks if a role has a specific permission func (s *Server) hasPermissionForRole(role, permission string) (bool, error) { - // Load policy to check RBAC permissions - policyPath := envutil.GetPolicyPath() - if policyPath == "" { - policyPath = ".sym/user-policy.json" + // Load policy path from config.json + projectCfg, _ := config.LoadProjectConfig() + policyPath := ".sym/user-policy.json" + if projectCfg != nil && projectCfg.PolicyPath != "" { + policyPath = projectCfg.PolicyPath } policyData, err := policy.LoadPolicy(policyPath) @@ -353,8 +354,9 @@ func (s *Server) handlePolicy(w http.ResponseWriter, r *http.Request) { // handleGetPolicy returns the current policy func (s 
*Server) handleGetPolicy(w http.ResponseWriter, r *http.Request) { - // Get policy path from .sym/.env (or use default) - policyPath := envutil.GetPolicyPath() + // Get policy path from .sym/config.json + projectCfg, _ := config.LoadProjectConfig() + policyPath := projectCfg.PolicyPath if policyPath == "" { policyPath = ".sym/user-policy.json" } @@ -404,8 +406,9 @@ func (s *Server) handleSavePolicy(w http.ResponseWriter, r *http.Request) { return } - // Get policy path from .sym/.env (or use default) - policyPath := envutil.GetPolicyPath() + // Get policy path from .sym/config.json + projectCfg, _ := config.LoadProjectConfig() + policyPath := projectCfg.PolicyPath if policyPath == "" { policyPath = ".sym/user-policy.json" } @@ -435,14 +438,11 @@ func (s *Server) handleSavePolicy(w http.ResponseWriter, r *http.Request) { func (s *Server) handlePolicyPath(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - // Load policy path from .sym/.env - policyPath := envutil.GetPolicyPath() + // Load policy path from .sym/config.json + projectCfg, _ := config.LoadProjectConfig() + policyPath := projectCfg.PolicyPath if policyPath == "" { - // Default to .sym/user-policy.json if not set policyPath = ".sym/user-policy.json" - fmt.Printf("No POLICY_PATH in .sym/.env, using default: %s\n", policyPath) - } else { - fmt.Printf("Loaded POLICY_PATH from .sym/.env: %s\n", policyPath) } w.Header().Set("Content-Type", "application/json") @@ -489,8 +489,9 @@ func (s *Server) handleSetPolicyPath(w http.ResponseWriter, r *http.Request) { fmt.Printf("Received policy path from client: '%s' (length: %d)\n", req.PolicyPath, len(req.PolicyPath)) - // Get current policy path from .env - oldPolicyPath := envutil.GetPolicyPath() + // Get current policy path from config.json + projectCfg, _ := config.LoadProjectConfig() + oldPolicyPath := projectCfg.PolicyPath if oldPolicyPath == "" { oldPolicyPath = ".sym/user-policy.json" // default } @@ -532,14 +533,15 @@ func (s 
*Server) handleSetPolicyPath(w http.ResponseWriter, r *http.Request) { } } - // Save to .sym/.env file - fmt.Printf("Saving policy path to .sym/.env: %s\n", req.PolicyPath) - if err := envutil.SaveKeyToEnvFile(filepath.Join(".sym", ".env"), "POLICY_PATH", req.PolicyPath); err != nil { + // Save to .sym/config.json + fmt.Printf("Saving policy path to config.json: %s\n", req.PolicyPath) + projectCfg.PolicyPath = req.PolicyPath + if err := config.SaveProjectConfig(projectCfg); err != nil { fmt.Printf("Failed to save policy path: %v\n", err) http.Error(w, fmt.Sprintf("Failed to save policy path: %v", err), http.StatusInternalServerError) return } - fmt.Printf("Policy path saved successfully to .sym/.env\n") + fmt.Printf("Policy path saved successfully to config.json\n") w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]string{ @@ -660,8 +662,9 @@ func (s *Server) handleConvert(w http.ResponseWriter, r *http.Request) { fmt.Println("Starting policy conversion...") - // Get policy path from .env - policyPath := envutil.GetPolicyPath() + // Get policy path from config.json + projectCfg, _ := config.LoadProjectConfig() + policyPath := projectCfg.PolicyPath if policyPath == "" { policyPath = ".sym/user-policy.json" } @@ -678,17 +681,20 @@ func (s *Server) handleConvert(w http.ResponseWriter, r *http.Request) { // Determine output directory (same as input file) outputDir := filepath.Dir(policyPath) - // Setup LLM client (backend auto-selection via @llm) - timeout := 30 * time.Second - llmClient := llm.NewClient( - llm.WithTimeout(timeout), - ) + // Setup LLM provider + llmCfg := llm.LoadConfig() + llmProvider, err := llm.New(llmCfg) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create LLM provider: %v", err), http.StatusInternalServerError) + return + } + defer llmProvider.Close() - // Create converter with LLM client and output directory - conv := converter.NewConverter(llmClient, outputDir) + // Create converter with LLM 
provider and output directory + conv := converter.NewConverter(llmProvider, outputDir) - // Setup context with timeout - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30*10)*time.Second) + // Setup context with timeout (10 minutes to match validator) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() // Convert using new API diff --git a/internal/ui/colors.go b/internal/ui/colors.go new file mode 100644 index 0000000..0685368 --- /dev/null +++ b/internal/ui/colors.go @@ -0,0 +1,114 @@ +package ui + +import ( + "fmt" + "os" + + "golang.org/x/term" +) + +// ANSI color codes +const ( + Reset = "\033[0m" + Red = "\033[31m" + Green = "\033[32m" + Yellow = "\033[33m" + Blue = "\033[34m" + Cyan = "\033[36m" + Bold = "\033[1m" +) + +// isTTY checks if stdout is a terminal +func isTTY() bool { + return term.IsTerminal(int(os.Stdout.Fd())) +} + +// colorize applies color only if output is a TTY +func colorize(color, msg string) string { + if !isTTY() { + return msg + } + return color + msg + Reset +} + +// OK formats a success message with [OK] prefix in green +func OK(msg string) string { + prefix := colorize(Green, "[OK]") + return fmt.Sprintf("%s %s", prefix, msg) +} + +// Error formats an error message with [ERROR] prefix in red +func Error(msg string) string { + prefix := colorize(Red, "[ERROR]") + return fmt.Sprintf("%s %s", prefix, msg) +} + +// Warn formats a warning message with [WARN] prefix in yellow +func Warn(msg string) string { + prefix := colorize(Yellow, "[WARN]") + return fmt.Sprintf("%s %s", prefix, msg) +} + +// Info formats an info message with [INFO] prefix in blue +func Info(msg string) string { + prefix := colorize(Blue, "[INFO]") + return fmt.Sprintf("%s %s", prefix, msg) +} + +// Title formats a section title in bold cyan +func Title(msg string) string { + prefix := colorize(Bold+Cyan, fmt.Sprintf("[%s]", msg)) + return prefix +} + +// TitleWithDesc formats a section title with 
description +func TitleWithDesc(title, desc string) string { + prefix := colorize(Bold+Cyan, fmt.Sprintf("[%s]", title)) + return fmt.Sprintf("%s %s", prefix, desc) +} + +// Done formats a completion message with [DONE] prefix in green +func Done(msg string) string { + prefix := colorize(Green+Bold, "[DONE]") + return fmt.Sprintf("%s %s", prefix, msg) +} + +// PrintOK prints a success message +func PrintOK(msg string) { + fmt.Println(OK(msg)) +} + +// PrintError prints an error message +func PrintError(msg string) { + fmt.Println(Error(msg)) +} + +// PrintWarn prints a warning message +func PrintWarn(msg string) { + fmt.Println(Warn(msg)) +} + +// PrintInfo prints an info message +func PrintInfo(msg string) { + fmt.Println(Info(msg)) +} + +// PrintTitle prints a section title +func PrintTitle(title, desc string) { + fmt.Println(TitleWithDesc(title, desc)) +} + +// PrintDone prints a completion message +func PrintDone(msg string) { + fmt.Println(Done(msg)) +} + +// Indent returns the message with indentation +func Indent(msg string) string { + return " " + msg +} + +// PrintIndent prints an indented message +func PrintIndent(msg string) { + fmt.Println(Indent(msg)) +} diff --git a/internal/validator/llm_validator.go b/internal/validator/llm_validator.go index b347ac8..71a0db6 100644 --- a/internal/validator/llm_validator.go +++ b/internal/validator/llm_validator.go @@ -2,7 +2,9 @@ package validator import ( "context" + "encoding/json" "fmt" + "runtime" "strings" "sync" @@ -30,15 +32,15 @@ type ValidationResult struct { // This validator is specifically for Git diff validation. // For regular file validation, use Validator which orchestrates all engines including LLM. 
type LLMValidator struct { - client *llm.Client + provider llm.Provider policy *schema.CodePolicy validator *Validator } // NewLLMValidator creates a new LLM validator -func NewLLMValidator(client *llm.Client, policy *schema.CodePolicy) *LLMValidator { +func NewLLMValidator(provider llm.Provider, policy *schema.CodePolicy) *LLMValidator { return &LLMValidator{ - client: client, + provider: provider, policy: policy, validator: NewValidator(policy, false), // Use main validator for orchestration } @@ -47,6 +49,7 @@ func NewLLMValidator(client *llm.Client, policy *schema.CodePolicy) *LLMValidato // Validate validates git changes against LLM-based rules. // This method is for diff-based validation (pre-commit hooks, PR validation). // For regular file validation, use validator.Validate() which orchestrates all engines. +// Concurrency is limited to CPU count to prevent CPU spike. func (v *LLMValidator) Validate(ctx context.Context, changes []GitChange) (*ValidationResult, error) { result := &ValidationResult{ Violations: make([]Violation, 0), @@ -59,9 +62,20 @@ func (v *LLMValidator) Validate(ctx context.Context, changes []GitChange) (*Vali } // Check each change against LLM rules using goroutines for parallel processing + // Limit concurrency to prevent resource exhaustion + // Use CPU/4, minimum 2, maximum 4 to balance performance and stability var wg sync.WaitGroup var mu sync.Mutex + maxConcurrent := runtime.NumCPU() / 4 + if maxConcurrent < 2 { + maxConcurrent = 2 + } + if maxConcurrent > 4 { + maxConcurrent = 4 + } + sem := make(chan struct{}, maxConcurrent) + for _, change := range changes { if change.Status == "D" { continue // Skip deleted files @@ -77,7 +91,7 @@ func (v *LLMValidator) Validate(ctx context.Context, changes []GitChange) (*Vali continue } - // Validate against each LLM rule in parallel + // Validate against each LLM rule in parallel with concurrency limit for _, rule := range llmRules { mu.Lock() result.Checked++ @@ -87,6 +101,10 @@ func (v 
*LLMValidator) Validate(ctx context.Context, changes []GitChange) (*Vali go func(ch GitChange, lines []string, r schema.PolicyRule) { defer wg.Done() + // Acquire semaphore + sem <- struct{}{} + defer func() { <-sem }() + violation, err := v.CheckRule(ctx, ch, lines, r) if err != nil { // Log error but continue @@ -191,8 +209,9 @@ Response: Analyze the code and determine if it violates the rule. Respond with JSON only.`, change.FilePath, rule.Desc, codeSnippet) - // Call LLM with low reasoning (needs thought for code validation) - response, err := v.client.Request(systemPrompt, userPrompt).Execute(ctx) + // Call LLM + prompt := systemPrompt + "\n\n" + userPrompt + response, err := v.provider.Execute(ctx, prompt, llm.JSON) if err != nil { return nil, err } @@ -284,45 +303,14 @@ func parseValidationResponse(response string) validationResponse { return result } -// parseJSON parses JSON string into the target struct +// parseJSON parses JSON string into the target struct using encoding/json func parseJSON(jsonStr string, target interface{}) error { - decoder := strings.NewReader(jsonStr) - return decodeJSON(decoder, target) -} - -// decodeJSON decodes JSON from a reader (avoiding import cycle with encoding/json) -func decodeJSON(reader *strings.Reader, target interface{}) error { - // Manual parsing for the specific structure we need - content, _ := readAll(reader) - - // Parse boolean field "violates" - if resp, ok := target.(*jsonValidationResponse); ok { - resp.Violates = strings.Contains(strings.ToLower(content), `"violates":true`) || - strings.Contains(strings.ToLower(content), `"violates": true`) - - resp.Confidence = extractJSONField(content, "confidence") - resp.Description = extractJSONField(content, "description") - resp.Suggestion = extractJSONField(content, "suggestion") + if err := json.Unmarshal([]byte(jsonStr), target); err != nil { + return fmt.Errorf("failed to parse JSON response: %w", err) } - return nil } -func readAll(reader *strings.Reader) 
(string, error) { - var builder strings.Builder - buf := make([]byte, 1024) - for { - n, err := reader.Read(buf) - if n > 0 { - builder.Write(buf[:n]) - } - if err != nil { - break - } - } - return builder.String(), nil -} - // parseValidationResponseFallback is used when JSON parsing fails func parseValidationResponseFallback(response string) validationResponse { result := validationResponse{ diff --git a/internal/validator/validator.go b/internal/validator/validator.go index b3294a8..86cdbda 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -43,7 +43,7 @@ type Validator struct { selector *FileSelector ctx context.Context ctxCancel context.CancelFunc - llmClient *llm.Client + llmProvider llm.Provider } // NewValidator creates a new adapter-based validator @@ -65,7 +65,7 @@ func NewValidator(policy *schema.CodePolicy, verbose bool) *Validator { selector: NewFileSelector(workDir), ctx: ctx, ctxCancel: cancel, - llmClient: nil, + llmProvider: nil, } } @@ -84,13 +84,13 @@ func NewValidatorWithWorkDir(policy *schema.CodePolicy, verbose bool, workDir st selector: NewFileSelector(workDir), ctx: ctx, ctxCancel: cancel, - llmClient: nil, + llmProvider: nil, } } -// SetLLMClient sets the LLM client for this validator -func (v *Validator) SetLLMClient(client *llm.Client) { - v.llmClient = client +// SetLLMProvider sets the LLM provider for this validator +func (v *Validator) SetLLMProvider(provider llm.Provider) { + v.llmProvider = provider } // executeRule executes a rule using the appropriate adapter @@ -176,8 +176,8 @@ func (v *Validator) executeRule(engineName string, rule schema.PolicyRule, files // executeLLMRule executes an LLM-based rule func (v *Validator) executeLLMRule(rule schema.PolicyRule, files []string) ([]Violation, error) { - if v.llmClient == nil { - return nil, fmt.Errorf("LLM client not configured") + if v.llmProvider == nil { + return nil, fmt.Errorf("LLM provider not configured") } // Validate required fields for LLM 
validator @@ -230,7 +230,8 @@ Does this code violate the convention?`, file, rule.Desc, string(content)) // Call LLM fileStartTime := time.Now() - response, err := v.llmClient.Request(systemPrompt, userPrompt).Execute(v.ctx) + prompt := systemPrompt + "\n\n" + userPrompt + response, err := v.llmProvider.Execute(v.ctx, prompt, llm.Text) fileExecMs := time.Since(fileStartTime).Milliseconds() // Record response in consolidated output @@ -464,13 +465,14 @@ func (v *Validator) ValidateChanges(ctx context.Context, changes []GitChange) (* // validateLLMChanges validates changes using LLM in parallel func (v *Validator) validateLLMChanges(ctx context.Context, changes []GitChange, rule schema.PolicyRule, result *ValidationResult) { - if v.llmClient == nil { + if v.llmProvider == nil { + fmt.Fprintf(os.Stderr, "⚠️ LLM provider not configured, skipping LLM validation for rule %s\n", rule.ID) return } var wg sync.WaitGroup var mu sync.Mutex - llmValidator := NewLLMValidator(v.llmClient, v.policy) + llmValidator := NewLLMValidator(v.llmProvider, v.policy) for _, change := range changes { if change.Status == "D" { @@ -500,6 +502,13 @@ func (v *Validator) validateLLMChanges(ctx context.Context, changes []GitChange, violation, err := llmValidator.CheckRule(ctx, ch, lines, r) if err != nil { + mu.Lock() + result.Errors = append(result.Errors, ValidationError{ + RuleID: r.ID, + Engine: "llm-validator", + Message: fmt.Sprintf("failed to check rule: %v", err), + }) + mu.Unlock() return } diff --git a/npm/package.json b/npm/package.json index 7d5bb41..94522a4 100644 --- a/npm/package.json +++ b/npm/package.json @@ -1,6 +1,6 @@ { "name": "@dev-symphony/sym", - "version": "0.1.5", + "version": "0.1.6", "description": "Symphony - LLM-friendly convention linter for AI coding assistants", "keywords": [ "mcp", diff --git a/tests/e2e/full_workflow_test.go b/tests/e2e/full_workflow_test.go index 696eb28..b1a9e83 100644 --- a/tests/e2e/full_workflow_test.go +++ b/tests/e2e/full_workflow_test.go 
@@ -73,12 +73,12 @@ func TestE2E_FullWorkflow(t *testing.T) { // ========== STEP 2: Convert natural language to structured policy ========== t.Log("STEP 2: Converting user policy using LLM") - client := llm.NewClient( - llm.WithTimeout(30 * time.Second), - ) + cfg := llm.LoadConfig() + provider, err := llm.New(cfg) + require.NoError(t, err, "LLM provider creation should succeed") outputDir := filepath.Join(testDir, ".sym") - conv := converter.NewConverter(client, outputDir) + conv := converter.NewConverter(provider, outputDir) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() @@ -168,7 +168,7 @@ func ProcessData(data string) error { // ========== STEP 4: Validate generated code ========== t.Log("STEP 4: Validating generated code against conventions") - llmValidator := validator.NewLLMValidator(client, &convertedPolicy) + llmValidator := validator.NewLLMValidator(provider, &convertedPolicy) // Validate BAD code t.Log("STEP 4a: Validating BAD code (should find violations)") @@ -332,8 +332,10 @@ func TestE2E_CodeGenerationFeedbackLoop(t *testing.T) { }, } - client := llm.NewClient() - v := validator.NewLLMValidator(client, policy) + cfg := llm.LoadConfig() + provider, err := llm.New(cfg) + require.NoError(t, err, "LLM provider creation should succeed") + v := validator.NewLLMValidator(provider, policy) ctx := context.Background() // Iteration 1: Bad code diff --git a/tests/e2e/mcp_integration_test.go b/tests/e2e/mcp_integration_test.go index 3767d7e..f955632 100644 --- a/tests/e2e/mcp_integration_test.go +++ b/tests/e2e/mcp_integration_test.go @@ -6,7 +6,6 @@ import ( "path/filepath" "strings" "testing" - "time" "github.com/DevSymphony/sym-cli/internal/llm" "github.com/DevSymphony/sym-cli/internal/validator" @@ -140,13 +139,13 @@ func TestMCP_ValidateAIGeneratedCode(t *testing.T) { policy, err := loadPolicy(policyPath) require.NoError(t, err) - // Create LLM client - client := llm.NewClient( - llm.WithTimeout(30 * time.Second), - ) 
+ // Create LLM provider + cfg := llm.LoadConfig() + provider, err := llm.New(cfg) + require.NoError(t, err, "LLM provider creation should succeed") // Create validator - v := validator.NewLLMValidator(client, policy) + v := validator.NewLLMValidator(provider, policy) ctx := context.Background() // Test 1: Validate BAD code (should find multiple violations) @@ -257,7 +256,7 @@ func TestMCP_ValidateAIGeneratedCode(t *testing.T) { Rules: filterRulesByCategory(policy.Rules, "security"), } - securityValidator := validator.NewLLMValidator(client, securityPolicy) + securityValidator := validator.NewLLMValidator(provider, securityPolicy) // Code with security violation (format as git diff with + prefix) codeWithSecurityIssue := `+const apiKey = "sk-1234567890abcdef"; // Hardcoded secret @@ -378,8 +377,10 @@ func TestMCP_EndToEndWorkflow(t *testing.T) { // Step 4: Validate generated code t.Log("STEP 4: Validating AI-generated code") - client := llm.NewClient() - v := validator.NewLLMValidator(client, policy) + llmCfg := llm.LoadConfig() + llmProvider, err := llm.New(llmCfg) + require.NoError(t, err, "LLM provider creation should succeed") + v := validator.NewLLMValidator(llmProvider, policy) result, err := v.Validate(context.Background(), []validator.GitChange{ {FilePath: "auth.js", Diff: generatedCode}, diff --git a/tests/e2e/validator_test.go b/tests/e2e/validator_test.go index 23e890a..f5bc81f 100644 --- a/tests/e2e/validator_test.go +++ b/tests/e2e/validator_test.go @@ -30,11 +30,13 @@ func TestE2E_ValidatorWithPolicy(t *testing.T) { require.NoError(t, err, "Failed to load policy") require.NotEmpty(t, policy.Rules, "Policy should have rules") - // Create LLM client - client := llm.NewClient() + // Create LLM provider + cfg := llm.LoadConfig() + provider, err := llm.New(cfg) + require.NoError(t, err, "LLM provider creation should succeed") // Create validator - v := validator.NewLLMValidator(client, policy) + v := validator.NewLLMValidator(provider, policy) // Create a 
test change (simulating git diff output) changes := []validator.GitChange{ @@ -82,11 +84,13 @@ func TestE2E_ValidatorWithGoodCode(t *testing.T) { policy, err := loadPolicy(".sym/code-policy.json") require.NoError(t, err) - // Create LLM client - client := llm.NewClient() + // Create LLM provider + cfg := llm.LoadConfig() + provider, err := llm.New(cfg) + require.NoError(t, err, "LLM provider creation should succeed") // Create validator - v := validator.NewLLMValidator(client, policy) + v := validator.NewLLMValidator(provider, policy) // Create a test change with good code changes := []validator.GitChange{ @@ -181,11 +185,13 @@ func TestE2E_ValidatorFilter(t *testing.T) { policy, err := loadPolicy(".sym/code-policy.json") require.NoError(t, err) - // Create LLM client - client := llm.NewClient() + // Create LLM provider + cfg := llm.LoadConfig() + provider, err := llm.New(cfg) + require.NoError(t, err, "LLM provider creation should succeed") // Create validator - v := validator.NewLLMValidator(client, policy) + v := validator.NewLLMValidator(provider, policy) // Test with Go file changes := []validator.GitChange{