From 7e83613483745ffe6b073504f4427cba8933fd1b Mon Sep 17 00:00:00 2001 From: NineOceans <44770303+LyuLumos@users.noreply.github.com> Date: Fri, 15 Dec 2023 15:19:00 +0000 Subject: [PATCH 1/4] Add support for Google Gemini --- src/providers/google/api.ts | 14 +++++++ src/providers/google/handler.ts | 70 +++++++++++++++++++++++++++++++++ src/providers/google/index.ts | 60 ++++++++++++++++++++++++++++ src/stores/provider.ts | 3 ++ 4 files changed, 147 insertions(+) create mode 100644 src/providers/google/api.ts create mode 100644 src/providers/google/handler.ts create mode 100644 src/providers/google/index.ts diff --git a/src/providers/google/api.ts b/src/providers/google/api.ts new file mode 100644 index 00000000..fed90456 --- /dev/null +++ b/src/providers/google/api.ts @@ -0,0 +1,14 @@ +export interface GoogleFetchPayload { + apiKey: string + body: Record +} + +export const fetchChatCompletion = async(payload: GoogleFetchPayload) => { + const { apiKey, body } = payload || {} + const initOptions = { + headers: { 'Content-Type': 'application/json' }, + method: 'POST', + body: JSON.stringify({ ...body }), + } + return fetch(`https://generativelanguage.googleapis.com/v1beta3/models/text-bison-001:generateText?key=${apiKey}`, initOptions); +} \ No newline at end of file diff --git a/src/providers/google/handler.ts b/src/providers/google/handler.ts new file mode 100644 index 00000000..048ceee9 --- /dev/null +++ b/src/providers/google/handler.ts @@ -0,0 +1,70 @@ +import { fetchChatCompletion } from "./api" +import type { Message } from '@/types/message' +import type { HandlerPayload, Provider } from '@/types/provider' + +export const handlePrompt: Provider['handlePrompt'] = async(payload, signal?: AbortSignal) => { + if (payload.botId === 'chat_continuous') + return handleChatCompletion(payload, signal) + if (payload.botId === 'chat_single') + return handleChatCompletion(payload, signal) +} + +export const handleRapidPrompt: Provider['handleRapidPrompt'] = async(prompt, 
globalSettings) => { + const rapidPromptPayload = { + conversationId: 'temp', + conversationType: 'chat_single', + botId: 'temp', + globalSettings: { + ...globalSettings, + }, + botSettings: {}, + prompt, + messages: { 'prompt': { 'text': prompt } }, + } as unknown as HandlerPayload + const result = await handleChatCompletion(rapidPromptPayload) + if (typeof result === 'string') + return result + return '' +} + +export const handleChatCompletion = async(payload: HandlerPayload, signal?: AbortSignal) => { + // An array to store the chat messages + const messages: Message[] = [] + + let maxTokens = payload.globalSettings.maxTokens as number + let messageHistorySize = payload.globalSettings.messageHistorySize as number + + // Iterate through the message history + while (messageHistorySize > 0) { + messageHistorySize-- + // Get the last message from the payload + const m = payload.messages.pop() + if (m === undefined) + break + + if (maxTokens - m.content.length < 0) + break + + maxTokens -= m.content.length + messages.unshift(m) + } + + const response = await fetchChatCompletion({ + apiKey: payload.globalSettings.apiKey as string, + body: { + prompt: { + text: messages.map(m => m.content).join('\n'), + } + }, + }) + + if (response.ok) { + const json = await response.json() + console.log('json', json) + const output = json.candidates[0].output || json.statusText || json.status || json + return output as string + } + + const text = await response.text() + throw new Error(`Failed to fetch chat completion: ${text}`) + } diff --git a/src/providers/google/index.ts b/src/providers/google/index.ts new file mode 100644 index 00000000..26d2297d --- /dev/null +++ b/src/providers/google/index.ts @@ -0,0 +1,60 @@ +import { + handlePrompt, + handleRapidPrompt, +} from './handler' +import type { Provider } from '@/types/provider' + +const providerGoogle = () => { + const provider: Provider = { + id: 'provider-google', + icon: 'i-simple-icons-google', // @unocss-include + name: 
'Google', + globalSettings: [ + { + key: 'apiKey', + name: 'API Key', + type: 'api-key', + }, + { + key: 'maxTokens', + name: 'Max Tokens', + description: 'The maximum number of tokens to generate in the completion.', + type: 'slider', + min: 0, + max: 32768, + default: 2048, + step: 1, + }, + { + key: 'messageHistorySize', + name: 'Max History Message Size', + description: 'The number of retained historical messages will be truncated if the length of the message exceeds the MaxToken parameter.', + type: 'slider', + min: 1, + max: 24, + default: 5, + step: 1, + }, + ], + bots: [ + { + id: 'chat_continuous', + type: 'chat_continuous', + name: 'Continuous Chat', + settings: [], + }, + { + id: 'chat_single', + type: 'chat_single', + name: 'Single Chat', + settings: [], + }, + + ], + handlePrompt, + handleRapidPrompt, + } + return provider +} + +export default providerGoogle \ No newline at end of file diff --git a/src/stores/provider.ts b/src/stores/provider.ts index c5f85cee..5b64a9d2 100644 --- a/src/stores/provider.ts +++ b/src/stores/provider.ts @@ -1,13 +1,16 @@ import providerOpenAI from '@/providers/openai' import providerAzure from '@/providers/azure' +import providerGoogle from '@/providers/google' import providerReplicate from '@/providers/replicate' import { allConversationTypes } from '@/types/conversation' import type { BotMeta } from '@/types/app' + export const providerList = [ providerOpenAI(), providerAzure(), providerReplicate(), + providerGoogle(), ] export const providerMetaList = providerList.map(provider => ({ From c2897584465614d39d119a1813993e4eb08d5477 Mon Sep 17 00:00:00 2001 From: NineOceans <44770303+LyuLumos@users.noreply.github.com> Date: Mon, 18 Dec 2023 09:35:25 +0000 Subject: [PATCH 2/4] fix: Change model from PaLM to Google Gemini API --- src/logics/conversation.ts | 6 +- src/providers/google/api.ts | 5 +- src/providers/google/handler.ts | 112 ++++++++++++++++---------------- src/providers/google/index.ts | 10 +++ src/types/message.ts 
| 2 +- 5 files changed, 74 insertions(+), 61 deletions(-) diff --git a/src/logics/conversation.ts b/src/logics/conversation.ts index 57d0e35f..c5fd94d9 100644 --- a/src/logics/conversation.ts +++ b/src/logics/conversation.ts @@ -34,6 +34,7 @@ export const handlePrompt = async(conversation: Conversation, prompt?: string, s setLoadingStateByConversationId(conversation.id, true) let providerResponse: PromptResponse + const systemRole = provider.name === 'Google' ? 'user' : 'system' const handlerPayload: HandlerPayload = { conversationId: conversation.id, conversationType: bot.type, @@ -42,7 +43,8 @@ export const handlePrompt = async(conversation: Conversation, prompt?: string, s botSettings: {}, prompt, messages: [ - ...(conversation.systemInfo ? [{ role: 'system', content: conversation.systemInfo }] : []) as Message[], + ...(conversation.systemInfo ? [{ role: systemRole, content: conversation.systemInfo }] : []) as Message[], + ...(provider.name === 'Google' && conversation.systemInfo ? [{ role: 'model', content: 'OK' }] : []) as Message[], // Google Gemini API currently only support odd number of messages. ...(destr(conversation.mockMessages) || []) as Message[], ...getMessagesByConversationId(conversation.id).map(message => ({ role: message.role, @@ -77,7 +79,7 @@ export const handlePrompt = async(conversation: Conversation, prompt?: string, s } pushMessageByConversationId(conversation.id, { id: messageId, - role: 'assistant', + role: provider.name === 'Google' ? 'model' : 'assistant', content: typeof providerResponse === 'string' ? 
providerResponse : '', stream: providerResponse instanceof ReadableStream, dateTime: new Date().getTime(), diff --git a/src/providers/google/api.ts b/src/providers/google/api.ts index fed90456..bbde24b3 100644 --- a/src/providers/google/api.ts +++ b/src/providers/google/api.ts @@ -1,14 +1,15 @@ export interface GoogleFetchPayload { apiKey: string body: Record + model?: string } export const fetchChatCompletion = async(payload: GoogleFetchPayload) => { - const { apiKey, body } = payload || {} + const { apiKey, body, model } = payload || {} const initOptions = { headers: { 'Content-Type': 'application/json' }, method: 'POST', body: JSON.stringify({ ...body }), } - return fetch(`https://generativelanguage.googleapis.com/v1beta3/models/text-bison-001:generateText?key=${apiKey}`, initOptions); + return fetch(`https://generativelanguage.googleapis.com/v1/models/${model}:generateContent?key=${apiKey}`, initOptions); } \ No newline at end of file diff --git a/src/providers/google/handler.ts b/src/providers/google/handler.ts index 048ceee9..66ba09fa 100644 --- a/src/providers/google/handler.ts +++ b/src/providers/google/handler.ts @@ -2,69 +2,69 @@ import { fetchChatCompletion } from "./api" import type { Message } from '@/types/message' import type { HandlerPayload, Provider } from '@/types/provider' -export const handlePrompt: Provider['handlePrompt'] = async(payload, signal?: AbortSignal) => { +export const handlePrompt: Provider['handlePrompt'] = async (payload, signal?: AbortSignal) => { if (payload.botId === 'chat_continuous') return handleChatCompletion(payload, signal) if (payload.botId === 'chat_single') return handleChatCompletion(payload, signal) } -export const handleRapidPrompt: Provider['handleRapidPrompt'] = async(prompt, globalSettings) => { - const rapidPromptPayload = { - conversationId: 'temp', - conversationType: 'chat_single', - botId: 'temp', - globalSettings: { - ...globalSettings, - }, - botSettings: {}, - prompt, - messages: { 'prompt': { 'text': 
prompt } }, - } as unknown as HandlerPayload - const result = await handleChatCompletion(rapidPromptPayload) - if (typeof result === 'string') - return result - return '' +export const handleRapidPrompt: Provider['handleRapidPrompt'] = async (prompt, globalSettings) => { + const rapidPromptPayload = { + conversationId: 'temp', + conversationType: 'chat_single', + botId: 'temp', + globalSettings: { + ...globalSettings, + model: 'gemini-pro', + }, + botSettings: {}, + prompt, + messages: { contents: [{ role: 'user', parts: [{ text: prompt }] }] }, + } as unknown as HandlerPayload + const result = await handleChatCompletion(rapidPromptPayload) + if (typeof result === 'string') + return result + return '' } export const handleChatCompletion = async(payload: HandlerPayload, signal?: AbortSignal) => { - // An array to store the chat messages - const messages: Message[] = [] - - let maxTokens = payload.globalSettings.maxTokens as number - let messageHistorySize = payload.globalSettings.messageHistorySize as number - - // Iterate through the message history - while (messageHistorySize > 0) { - messageHistorySize-- - // Get the last message from the payload - const m = payload.messages.pop() - if (m === undefined) - break - - if (maxTokens - m.content.length < 0) - break - - maxTokens -= m.content.length - messages.unshift(m) - } - - const response = await fetchChatCompletion({ - apiKey: payload.globalSettings.apiKey as string, - body: { - prompt: { - text: messages.map(m => m.content).join('\n'), - } - }, - }) + // An array to store the chat messages + const messages: Message[] = [] - if (response.ok) { - const json = await response.json() - console.log('json', json) - const output = json.candidates[0].output || json.statusText || json.status || json - return output as string - } - - const text = await response.text() - throw new Error(`Failed to fetch chat completion: ${text}`) - } + let maxTokens = payload.globalSettings.maxTokens as number + let messageHistorySize = 
payload.globalSettings.messageHistorySize as number + + // Iterate through the message history + while (messageHistorySize > 0) { + messageHistorySize-- + // Get the last message from the payload + const m = payload.messages.pop() + if (m === undefined) + break + + if (maxTokens - m.content.length < 0) + break + + maxTokens -= m.content.length + messages.unshift(m) + } + + const response = await fetchChatCompletion({ + apiKey: payload.globalSettings.apiKey as string, + body: { + contents: messages.map((m) => ({ role: m.role, parts: [{ text: m.content }] })), + }, + model: payload.globalSettings.model as string, + }) + + if (response.ok) { + const json = await response.json() + // console.log('json', json) + const output = json.candidates[0].content.parts[0].text || json + return output as string + } + + const text = await response.text() + throw new Error(`Failed to fetch chat completion: ${text}`) +} diff --git a/src/providers/google/index.ts b/src/providers/google/index.ts index 26d2297d..6ad0e9c3 100644 --- a/src/providers/google/index.ts +++ b/src/providers/google/index.ts @@ -15,6 +15,16 @@ const providerGoogle = () => { name: 'API Key', type: 'api-key', }, + { + key: 'model', + name: 'Google model', + description: 'Custom model for Google API.', + type: 'select', + options: [ + { value: 'gemini-pro', label: 'gemini-pro' }, + ], + default: 'gemini-pro', + }, { key: 'maxTokens', name: 'Max Tokens', diff --git a/src/types/message.ts b/src/types/message.ts index 5c6f119c..711b37aa 100644 --- a/src/types/message.ts +++ b/src/types/message.ts @@ -1,5 +1,5 @@ export interface Message { - role: 'system' | 'user' | 'assistant' + role: 'system' | 'user' | 'assistant' | 'model' content: string } From 8b63d3c04efbe1b534c1b189659a37ff86498e76 Mon Sep 17 00:00:00 2001 From: ddiu8081 Date: Tue, 19 Dec 2023 03:37:00 +0800 Subject: [PATCH 3/4] chore: code style --- src/providers/google/api.ts | 22 +++++++++++----------- src/providers/google/handler.ts | 8 ++++---- 
src/providers/google/index.ts | 4 ++-- src/stores/provider.ts | 1 - 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/providers/google/api.ts b/src/providers/google/api.ts index bbde24b3..723f47e0 100644 --- a/src/providers/google/api.ts +++ b/src/providers/google/api.ts @@ -1,15 +1,15 @@ export interface GoogleFetchPayload { - apiKey: string - body: Record - model?: string + apiKey: string + body: Record + model?: string } export const fetchChatCompletion = async(payload: GoogleFetchPayload) => { - const { apiKey, body, model } = payload || {} - const initOptions = { - headers: { 'Content-Type': 'application/json' }, - method: 'POST', - body: JSON.stringify({ ...body }), - } - return fetch(`https://generativelanguage.googleapis.com/v1/models/${model}:generateContent?key=${apiKey}`, initOptions); -} \ No newline at end of file + const { apiKey, body, model } = payload || {} + const initOptions = { + headers: { 'Content-Type': 'application/json' }, + method: 'POST', + body: JSON.stringify({ ...body }), + } + return fetch(`https://generativelanguage.googleapis.com/v1/models/${model}:generateContent?key=${apiKey}`, initOptions) +} diff --git a/src/providers/google/handler.ts b/src/providers/google/handler.ts index 66ba09fa..810bfc85 100644 --- a/src/providers/google/handler.ts +++ b/src/providers/google/handler.ts @@ -1,15 +1,15 @@ -import { fetchChatCompletion } from "./api" +import { fetchChatCompletion } from './api' import type { Message } from '@/types/message' import type { HandlerPayload, Provider } from '@/types/provider' -export const handlePrompt: Provider['handlePrompt'] = async (payload, signal?: AbortSignal) => { +export const handlePrompt: Provider['handlePrompt'] = async(payload, signal?: AbortSignal) => { if (payload.botId === 'chat_continuous') return handleChatCompletion(payload, signal) if (payload.botId === 'chat_single') return handleChatCompletion(payload, signal) } -export const handleRapidPrompt: Provider['handleRapidPrompt'] 
= async (prompt, globalSettings) => { +export const handleRapidPrompt: Provider['handleRapidPrompt'] = async(prompt, globalSettings) => { const rapidPromptPayload = { conversationId: 'temp', conversationType: 'chat_single', @@ -53,7 +53,7 @@ export const handleChatCompletion = async(payload: HandlerPayload, signal?: Abor const response = await fetchChatCompletion({ apiKey: payload.globalSettings.apiKey as string, body: { - contents: messages.map((m) => ({ role: m.role, parts: [{ text: m.content }] })), + contents: messages.map(m => ({ role: m.role, parts: [{ text: m.content }] })), }, model: payload.globalSettings.model as string, }) diff --git a/src/providers/google/index.ts b/src/providers/google/index.ts index 6ad0e9c3..4023b2ea 100644 --- a/src/providers/google/index.ts +++ b/src/providers/google/index.ts @@ -63,8 +63,8 @@ const providerGoogle = () => { ], handlePrompt, handleRapidPrompt, - } + } return provider } -export default providerGoogle \ No newline at end of file +export default providerGoogle diff --git a/src/stores/provider.ts b/src/stores/provider.ts index 5b64a9d2..e45e230f 100644 --- a/src/stores/provider.ts +++ b/src/stores/provider.ts @@ -5,7 +5,6 @@ import providerReplicate from '@/providers/replicate' import { allConversationTypes } from '@/types/conversation' import type { BotMeta } from '@/types/app' - export const providerList = [ providerOpenAI(), providerAzure(), From 29f45214981962d8b055f806907e870a9bbdf679 Mon Sep 17 00:00:00 2001 From: ddiu8081 Date: Tue, 19 Dec 2023 04:02:06 +0800 Subject: [PATCH 4/4] refactor: move Gemini message parser to provider --- src/logics/conversation.ts | 6 ++---- src/providers/google/handler.ts | 3 ++- src/providers/google/parser.ts | 35 +++++++++++++++++++++++++++++++++ src/types/message.ts | 2 +- 4 files changed, 40 insertions(+), 6 deletions(-) create mode 100644 src/providers/google/parser.ts diff --git a/src/logics/conversation.ts b/src/logics/conversation.ts index c5fd94d9..57d0e35f 100644 --- 
a/src/logics/conversation.ts +++ b/src/logics/conversation.ts @@ -34,7 +34,6 @@ export const handlePrompt = async(conversation: Conversation, prompt?: string, s setLoadingStateByConversationId(conversation.id, true) let providerResponse: PromptResponse - const systemRole = provider.name === 'Google' ? 'user' : 'system' const handlerPayload: HandlerPayload = { conversationId: conversation.id, conversationType: bot.type, @@ -43,8 +42,7 @@ export const handlePrompt = async(conversation: Conversation, prompt?: string, s botSettings: {}, prompt, messages: [ - ...(conversation.systemInfo ? [{ role: systemRole, content: conversation.systemInfo }] : []) as Message[], - ...(provider.name === 'Google' && conversation.systemInfo ? [{ role: 'model', content: 'OK' }] : []) as Message[], // Google Gemini API currently only support odd number of messages. + ...(conversation.systemInfo ? [{ role: 'system', content: conversation.systemInfo }] : []) as Message[], ...(destr(conversation.mockMessages) || []) as Message[], ...getMessagesByConversationId(conversation.id).map(message => ({ role: message.role, @@ -79,7 +77,7 @@ export const handlePrompt = async(conversation: Conversation, prompt?: string, s } pushMessageByConversationId(conversation.id, { id: messageId, - role: provider.name === 'Google' ? 'model' : 'assistant', + role: 'assistant', content: typeof providerResponse === 'string' ? 
providerResponse : '', stream: providerResponse instanceof ReadableStream, dateTime: new Date().getTime(), diff --git a/src/providers/google/handler.ts b/src/providers/google/handler.ts index 810bfc85..8b64c239 100644 --- a/src/providers/google/handler.ts +++ b/src/providers/google/handler.ts @@ -1,4 +1,5 @@ import { fetchChatCompletion } from './api' +import { parseMessageList } from './parser' import type { Message } from '@/types/message' import type { HandlerPayload, Provider } from '@/types/provider' @@ -53,7 +54,7 @@ export const handleChatCompletion = async(payload: HandlerPayload, signal?: Abor const response = await fetchChatCompletion({ apiKey: payload.globalSettings.apiKey as string, body: { - contents: messages.map(m => ({ role: m.role, parts: [{ text: m.content }] })), + contents: parseMessageList(messages), }, model: payload.globalSettings.model as string, }) diff --git a/src/providers/google/parser.ts b/src/providers/google/parser.ts new file mode 100644 index 00000000..0af47037 --- /dev/null +++ b/src/providers/google/parser.ts @@ -0,0 +1,35 @@ +import type { Message } from '@/types/message' + +export const parseMessageList = (rawList: Message[]) => { + interface GoogleGeminiMessage { + role: 'user' | 'model' + // TODO: Add support for image input + parts: [ + { text: string }, + ] + } + + if (rawList.length === 0) + return [] + + const parsedList: GoogleGeminiMessage[] = [] + // if first message is system message, insert an empty message after it + if (rawList[0].role === 'system') { + parsedList.push({ role: 'user', parts: [{ text: rawList[0].content }] }) + parsedList.push({ role: 'model', parts: [{ text: 'OK.' 
}] }) + } + // convert other messages + const roleDict = { + user: 'user', + assistant: 'model', + } as const + for (const message of rawList) { + if (message.role === 'system') + continue + parsedList.push({ + role: roleDict[message.role], + parts: [{ text: message.content }], + }) + } + return parsedList +} diff --git a/src/types/message.ts b/src/types/message.ts index 711b37aa..5c6f119c 100644 --- a/src/types/message.ts +++ b/src/types/message.ts @@ -1,5 +1,5 @@ export interface Message { - role: 'system' | 'user' | 'assistant' | 'model' + role: 'system' | 'user' | 'assistant' content: string }