-
Notifications
You must be signed in to change notification settings - Fork 1
/
server_llm_connector.go
170 lines (141 loc) · 5.68 KB
/
server_llm_connector.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
package main
import (
"bytes"
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"sync/atomic"
)
// # LLM User Query Struct
//
// This struct is used to represent a user query that is sent to the LLM server.
type LlmUserQuery struct {
ChannelId int `json:"CHANNEL_ID"` // The channel ID.
UserId string `json:"USER_ID"` // The user ID.
Query string `json:"USER_QUERY"` // The user query.
}
// # LLM Model Response Struct
//
// This struct is used to represent the response from the LLM server.
type LlmModelResponse struct {
Reference string `json:"reference_text"` // The reference part of the model response.
Response string `json:"response_text"` // The main model response.
}
// # LLM Connector
//
// This is the main LLM connector struct.
type LlmConnector struct {
ChannelMap ChannelIdConfigMap // A map from channel ID to channel configuration.
LocalDebugMode bool // Indicates if the LLM connector is in local debug mode.
waitingCounter int32 // Indicates the current waiting requests.
}
// # New LLM Connector
//
// This function creates a new LLM connector instance.
func NewLlmConnector(channelMap ChannelIdConfigMap, LocalDebugMode bool) *LlmConnector {
return &LlmConnector{
ChannelMap: channelMap, // Set the channel map.
LocalDebugMode: LocalDebugMode, // Set the local debug mode.
}
}
func (c *LlmConnector) LlmCallback(bot *TaipeionBot, event ChatbotWebhookEvent) error {
// Check if incoming event is text message.
// We do not support non-text messages for now, we may implement it later.
if event.Message.Type != "text" {
log.Println("[PrivateMessageCallback] Received non-text message. Ignoring.")
return nil
}
// Check if channel ID is in the channel map.
if _, ok := c.ChannelMap[event.Destination]; !ok {
log.Printf("[LlmCallback] Channel ID (%d) not found in config. Ignoring.\n", event.Destination)
return nil
}
// Information gathering.
chan_id := event.Destination // Channel ID
userId := event.Source.UserId // User ID
userQuery := event.Message.Text // User query
trigger_word := c.ChannelMap[chan_id].ChannelTriggerPrefix // Trigger word
log.Printf("[LlmCallback] Received user (%s) query on channel (%d): %s\n", userId, chan_id, userQuery)
// Check if the user query starts with the trigger word.
if !strings.HasPrefix(strings.TrimSpace(userQuery), trigger_word) {
log.Printf("[LlmCallback] User query does not start with trigger word (%s). Ignoring.\n", trigger_word)
return nil
}
// Check debug mode.
if c.LocalDebugMode {
log.Println("[LlmCallback] Local debug mode is enabled.")
log.Printf("[LlmCallback] [Debug info] User ID: %s, Channel ID: %d, User Query: %s Trigger Word: %s\n", userId, chan_id, userQuery, trigger_word)
log.Println("[LlmCallback] The following procedure is sending the user query to the LLM server in normal mode.")
log.Println("[LlmCallback] LLM Callback will now exit.")
return nil
}
// Send a friendly message.
err := bot.SendPrivateMessage(userId, fmt.Sprintf("正在處理您的問題,視當前情況大約需要30秒~數分鐘不等\n感謝您的耐心等待!\n(目前排隊: %d)", c.waitingCounter), chan_id)
if err != nil {
return err
}
atomic.AddInt32(&c.waitingCounter, 1) // Add waiting counter by 1.
// Create a new user query.
userQueryPayload := LlmUserQuery{
ChannelId: chan_id,
UserId: userId,
Query: userQuery,
}
// Send the user query to the LLM server.
response, err := c.LlmRequestSender(userQueryPayload)
if err != nil {
log.Println("[LlmCallback] Unable to send user query to LLM server:", err)
return err
}
// Create model response.
concatedResponse := fmt.Sprintf("%s\n\n%s", response.Response, response.Reference)
// PATCH: Replace certain character fo aviod false-positive of SQL injection.
false_positive_chars := []string{"'", ","}
for _, char := range false_positive_chars {
concatedResponse = strings.ReplaceAll(concatedResponse, char, " ")
}
log.Printf("[LlmCallback] Model response for user (%s) on channel (%d): %s\n", userId, chan_id, concatedResponse)
atomic.AddInt32(&c.waitingCounter, -1) // Decrease waiting counter by 1.
return bot.SendPrivateMessage(userId, concatedResponse, chan_id) // Send final result.
}
// # LLM Request Sender
//
// This function sends a user query to the LLM server and returns the response.
func (c *LlmConnector) LlmRequestSender(prompt LlmUserQuery) (LlmModelResponse, error) {
// Serialize the user query.
request_payload, err := json.Marshal(prompt)
if err != nil {
log.Println("[LlmConnector] Unable to serialize user query:", err)
return LlmModelResponse{}, err
}
// Create a new HTTP request.
req, err := http.NewRequest(
"POST",
c.ChannelMap[prompt.ChannelId].ChannelLlmEndpoint,
bytes.NewBuffer(request_payload))
if err != nil {
log.Println("[LlmConnector] Unable to create request:", err)
return LlmModelResponse{}, err
}
// Set the request headers.
req.Header.Set("Content-Type", "application/json")
// Perform the request.
client := &http.Client{} // Create a new HTTP client.
resp, err := client.Do(req) // Perform the request.
if err != nil {
log.Println("[LlmConnector] Unable to perform request:", err)
return LlmModelResponse{}, err
}
defer resp.Body.Close() // Close the response body when done.
// Create a new model response.
modelResp := LlmModelResponse{} // Create a new model response.
err = json.NewDecoder(resp.Body).Decode(&modelResp) // Decode the response body.
if err != nil {
log.Println("[LlmConnector] Unable to decode response:", err)
return LlmModelResponse{}, err
}
// Return the model response.
return modelResp, nil
}