diff --git a/neo/message/json.go b/neo/message/json.go index 1f1528967..4b2a7f9ec 100644 --- a/neo/message/json.go +++ b/neo/message/json.go @@ -1,9 +1,9 @@ package message import ( - "io" "strings" + "github.com/gin-gonic/gin" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/helper" "github.com/yaoapp/kun/log" @@ -47,6 +47,10 @@ func NewOpenAI(data []byte) *JSON { msg.Done = true break + case strings.Contains(text, `"finish_reason":"stop"`): + msg.Done = true + break + default: msg.Error = text } @@ -189,7 +193,13 @@ func (json *JSON) IsDone() bool { } // Write the message -func (json *JSON) Write(w io.Writer) bool { +func (json *JSON) Write(w gin.ResponseWriter) bool { + + defer func() { + if r := recover(); r != nil { + log.Error("Write JSON Message Error: %s", r) + } + }() data, err := jsoniter.Marshal(json.Message) if err != nil { @@ -205,7 +215,7 @@ func (json *JSON) Write(w io.Writer) bool { log.Error("%s", err.Error()) return false } - + w.Flush() return true } diff --git a/neo/neo.go b/neo/neo.go index d6d030222..a2c4cc67f 100644 --- a/neo/neo.go +++ b/neo/neo.go @@ -2,14 +2,12 @@ package neo import ( "fmt" - "io" "net/url" "strings" "github.com/fatih/color" "github.com/gin-gonic/gin" "github.com/google/uuid" - jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/api" "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/process" @@ -172,180 +170,140 @@ func (neo *DSL) API(router *gin.Engine, path string) error { return nil } -// Answer the message -func (neo *DSL) Answer(ctx command.Context, question string, answer Answer) error { - - chanStream := make(chan *message.JSON, 1) - chanError := make(chan error, 1) - content := []byte{} - errorMsg := []byte{} - +// Answer reply the message +func (neo *DSL) Answer(ctx command.Context, question string, c *gin.Context) error { // get the chat messages messages, err := neo.chatMessages(ctx, question) if err != nil { return err } - // check the command - cmd, isCommand := neo.matchCommand(ctx, messages) + clientBreak := make(chan bool, 1) + done := make(chan bool, 1) + content := []byte{} + + // Execute the command or chat with AI in the background go func() { - defer func() { - close(chanStream) - close(chanError) - }() - // execute the command - if isCommand { + // chat with AI + c.Header("Content-Type", "text/event-stream;charset=utf-8") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + // check the command + cmd, isCommand := neo.matchCommand(ctx, messages) + if isCommand { + // execute the command req, err := cmd.NewRequest(ctx, neo.Conversation) if err != nil { - chanError <- err + log.Error("Command with AI error: %s", err.Error()) + done <- true return } - log.Trace("Command with AI: question: %s messages:%v", question, messages) err = req.Run(messages, func(msg *message.JSON) int { - chanStream <- msg + err := neo.send(ctx, msg, messages, content, c) + if err != nil { + c.Status(500) + return 0 // break + } + + // Complete the stream + if msg.IsDone() { + return 0 // break + } return 1 }) if err != nil { - chanError <- err + c.Status(500) + log.Error("Command with AI error: %s", err.Error()) } return } - // chat with AI - log.Trace("Chat with AI: question:%s messages:%v", question, messages) _, ex := neo.AI.ChatCompletionsWith(ctx, messages, neo.Option, func(data []byte) int { - chanStream <- message.NewOpenAI(data) - return 1 - }) - if ex != nil { - chanError <- fmt.Errorf("AI chat error: %s", ex.Message) - } + select { + case <-clientBreak: + return 0 // break + default: - defer neo.saveHistory(ctx.Sid, content, messages) + msg := message.NewOpenAI(data) + if msg == nil { + return 1 // continue success + } - }() + if msg.Error != "" { + neo.send(ctx, msg, messages, content, c) + return 0 // break + } - answer.Header("Content-Type", "text/event-stream;charset=utf-8") - ok := answer.Stream(func(w io.Writer) bool { - select { - case err := <-chanError: - if err != nil { - message.New().Text(err.Error()).Write(w) - } + content = msg.Append(content) + err := neo.send(ctx, msg, messages, content, c) + if err != nil { + c.Status(500) + return 0 // break + } - if len(errorMsg) > 0 { - - var errData openai.ErrorMessage - err := jsoniter.Unmarshal(errorMsg, &errData) - if err == nil { - msg := errData.Error.Message - if msg == "" { - msg = fmt.Sprintf("OpenAI error: %s", errData.Error.Code) - } - message.New().Text(msg).Write(w) - message.New().Done().Write(w) - return false + // Complete the stream + if msg.IsDone() { + done <- true + return 0 // break } - message.New().Text(string(errorMsg)).Write(w) - message.New().Done().Write(w) - return false + return 1 // continue success } + }) - message.New().Done().Write(w) - return false + // Throw the error + if ex != nil { + log.Error("Neo chat error: %s", ex.Message) + c.Status(200) + done <- true + return + } - case msg := <-chanStream: - if msg == nil { - return true - } + // save the history + neo.saveHistory(ctx.Sid, content, messages) + c.Status(200) - if msg.Error != "" { - errorMsg = append(errorMsg, []byte(msg.Error)...) - return true - } + // Complete the stream + done <- true - content = msg.Append(content) - err := neo.write(msg, w, ctx, messages, content) - if err != nil { - log.Warn("Neo write process msg: %v error: %s", msg, err.Error()) - msg.Write(w) - } - - return !msg.IsDone() - - // case <-ctx.Done(): - // if err := ctx.Err(); err != nil { - // message.New().Text(err.Error()).Write(w) - // } - - // if len(errorMsg) > 0 { - - // var errData openai.ErrorMessage - // err := jsoniter.Unmarshal(errorMsg, &errData) - // if err == nil { - // msg := errData.Error.Message - // if msg == "" { - // msg = fmt.Sprintf("OpenAI error: %s", errData.Error.Code) - // } - // message.New().Text(msg).Write(w) - // message.New().Done().Write(w) - // return false - // } - - // message.New().Text(string(errorMsg)).Write(w) - // message.New().Done().Write(w) - // return false - // } - - // message.New().Done().Write(w) - // return false - } - }) + }() - if !ok { - answer.Status(500) + select { + case <-done: + return nil + case <-c.Writer.CloseNotify(): + clientBreak <- true return nil } - answer.Status(200) - return nil } -// prompts get the prompts -func (neo *DSL) prompts() []map[string]interface{} { - prompts := []map[string]interface{}{} - for _, prompt := range neo.Prompts { - message := map[string]interface{}{"role": prompt.Role, "content": prompt.Content} - if prompt.Name != "" { - message["name"] = prompt.Name - } - prompts = append(prompts, message) - } +// Send send the message to the stream +func (neo *DSL) send(ctx command.Context, msg *message.JSON, messages []map[string]interface{}, content []byte, c *gin.Context) error { - return prompts -} - -// after the after hook -func (neo *DSL) write(msg *message.JSON, w io.Writer, ctx command.Context, messages []map[string]interface{}, content []byte) error { + w := c.Writer + // Directly write the message if neo.Write == "" { - msg.Write(w) + ok := msg.Write(c.Writer) + if !ok { + return fmt.Errorf("Stream write error") + } return nil } - args := []interface{}{ctx, messages, msg, string(content)} + // Execute the custom write hook get the response + args := []interface{}{ctx, messages, msg, string(content), w} p, err := process.Of(neo.Write, args...) if err != nil { - log.Error("Neo custom write process error: %s", err.Error()) msg.Write(w) - return nil + return fmt.Errorf("Stream write error: %s", err.Error()) } err = p.WithSID(ctx.Sid).Execute() @@ -355,11 +313,13 @@ func (neo *DSL) write(msg *message.JSON, w io.Writer, ctx command.Context, messa return nil } defer p.Release() + res := p.Value() if res == nil { return fmt.Errorf("Neo custom write return null") } + // Send the custom write response to the stream if messages, ok := res.([]interface{}); ok { for _, new := range messages { if v, ok := new.(map[string]interface{}); ok { @@ -370,7 +330,21 @@ func (neo *DSL) write(msg *message.JSON, w io.Writer, ctx command.Context, messa return nil } - return fmt.Errorf("Neo custom write return not map") + return fmt.Errorf("Neo should return an array of response") +} + +// prompts get the prompts +func (neo *DSL) prompts() []map[string]interface{} { + prompts := []map[string]interface{}{} + for _, prompt := range neo.Prompts { + message := map[string]interface{}{"role": prompt.Role, "content": prompt.Content} + if prompt.Name != "" { + message["name"] = prompt.Name + } + prompts = append(prompts, message) + } + + return prompts } // prepare the messages @@ -493,20 +467,17 @@ func (neo *DSL) getCorsHandlers(router *gin.Engine, path string) ([]gin.HandlerF allowsMap[allow] = true } - router.OPTIONS(path, func(c *gin.Context) { c.AbortWithStatus(204) }) - router.OPTIONS(path+"/commands", func(c *gin.Context) { c.AbortWithStatus(204) }) - router.OPTIONS(path+"/history", func(c *gin.Context) { c.AbortWithStatus(204) }) + router.OPTIONS(path+"/history", neo.optionsHandler) + router.OPTIONS(path+"/commands", neo.optionsHandler) return []gin.HandlerFunc{ func(c *gin.Context) { - referer := c.Request.Referer() + referer := neo.getOrigin(c) if referer != "" { - if !api.IsAllowed(c, allowsMap) { c.JSON(403, gin.H{"message": referer + " not allowed", "code": 403}) c.Abort() return } - url, _ := url.Parse(referer) referer = fmt.Sprintf("%s://%s", url.Scheme, url.Host) c.Writer.Header().Set("Access-Control-Allow-Origin", referer) @@ -519,6 +490,24 @@ func (neo *DSL) getCorsHandlers(router *gin.Engine, path string) ([]gin.HandlerF }, nil } +func (neo *DSL) optionsHandler(c *gin.Context) { + origin := neo.getOrigin(c) + c.Writer.Header().Set("Access-Control-Allow-Origin", origin) + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + c.AbortWithStatus(204) +} + +func (neo *DSL) getOrigin(c *gin.Context) string { + referer := c.Request.Referer() + origin := c.Request.Header.Get("Origin") + if origin == "" { + origin = referer + } + return origin +} + func (neo *DSL) getGuardHandlers() ([]gin.HandlerFunc, error) { if neo.Guard == "" { diff --git a/neo/process.go b/neo/process.go new file mode 100644 index 000000000..598b40ad8 --- /dev/null +++ b/neo/process.go @@ -0,0 +1,41 @@ +package neo + +import ( + "github.com/gin-gonic/gin" + "github.com/yaoapp/gou/process" + "github.com/yaoapp/kun/exception" + "github.com/yaoapp/yao/neo/message" +) + +func init() { + process.RegisterGroup("neo", map[string]process.Handler{ + "write": ProcessWrite, + }) +} + +// ProcessWrite process the write request +func ProcessWrite(process *process.Process) interface{} { + + process.ValidateArgNums(2) + + w, ok := process.Args[0].(gin.ResponseWriter) + if !ok { + exception.New("The first argument must be a io.Writer", 400).Throw() + return nil + } + + data, ok := process.Args[1].([]interface{}) + if !ok { + exception.New("The second argument must be a Array", 400).Throw() + return nil + } + + for _, new := range data { + if v, ok := new.(map[string]interface{}); ok { + newMsg := message.New().Map(v) + newMsg.Write(w) + } + } + + return nil +}