-
Notifications
You must be signed in to change notification settings - Fork 433
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #115 from LyuLumos/main
feat: Add support for Google Gemini
Showing
5 changed files
with
193 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
export interface GoogleFetchPayload { | ||
apiKey: string | ||
body: Record<string, any> | ||
model?: string | ||
} | ||
|
||
export const fetchChatCompletion = async(payload: GoogleFetchPayload) => { | ||
const { apiKey, body, model } = payload || {} | ||
const initOptions = { | ||
headers: { 'Content-Type': 'application/json' }, | ||
method: 'POST', | ||
body: JSON.stringify({ ...body }), | ||
} | ||
return fetch(`https://generativelanguage.googleapis.com/v1/models/${model}:generateContent?key=${apiKey}`, initOptions) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import { fetchChatCompletion } from './api' | ||
import { parseMessageList } from './parser' | ||
import type { Message } from '@/types/message' | ||
import type { HandlerPayload, Provider } from '@/types/provider' | ||
|
||
export const handlePrompt: Provider['handlePrompt'] = async(payload, signal?: AbortSignal) => { | ||
if (payload.botId === 'chat_continuous') | ||
return handleChatCompletion(payload, signal) | ||
if (payload.botId === 'chat_single') | ||
return handleChatCompletion(payload, signal) | ||
} | ||
|
||
export const handleRapidPrompt: Provider['handleRapidPrompt'] = async(prompt, globalSettings) => { | ||
const rapidPromptPayload = { | ||
conversationId: 'temp', | ||
conversationType: 'chat_single', | ||
botId: 'temp', | ||
globalSettings: { | ||
...globalSettings, | ||
model: 'gemini-pro', | ||
}, | ||
botSettings: {}, | ||
prompt, | ||
messages: { contents: [{ role: 'user', parts: [{ text: prompt }] }] }, | ||
} as unknown as HandlerPayload | ||
const result = await handleChatCompletion(rapidPromptPayload) | ||
if (typeof result === 'string') | ||
return result | ||
return '' | ||
} | ||
|
||
export const handleChatCompletion = async(payload: HandlerPayload, signal?: AbortSignal) => { | ||
// An array to store the chat messages | ||
const messages: Message[] = [] | ||
|
||
let maxTokens = payload.globalSettings.maxTokens as number | ||
let messageHistorySize = payload.globalSettings.messageHistorySize as number | ||
|
||
// Iterate through the message history | ||
while (messageHistorySize > 0) { | ||
messageHistorySize-- | ||
// Get the last message from the payload | ||
const m = payload.messages.pop() | ||
if (m === undefined) | ||
break | ||
|
||
if (maxTokens - m.content.length < 0) | ||
break | ||
|
||
maxTokens -= m.content.length | ||
messages.unshift(m) | ||
} | ||
|
||
const response = await fetchChatCompletion({ | ||
apiKey: payload.globalSettings.apiKey as string, | ||
body: { | ||
contents: parseMessageList(messages), | ||
}, | ||
model: payload.globalSettings.model as string, | ||
}) | ||
|
||
if (response.ok) { | ||
const json = await response.json() | ||
// console.log('json', json) | ||
const output = json.candidates[0].content.parts[0].text || json | ||
return output as string | ||
} | ||
|
||
const text = await response.text() | ||
throw new Error(`Failed to fetch chat completion: ${text}`) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import { | ||
handlePrompt, | ||
handleRapidPrompt, | ||
} from './handler' | ||
import type { Provider } from '@/types/provider' | ||
|
||
const providerGoogle = () => { | ||
const provider: Provider = { | ||
id: 'provider-google', | ||
icon: 'i-simple-icons-google', // @unocss-include | ||
name: 'Google', | ||
globalSettings: [ | ||
{ | ||
key: 'apiKey', | ||
name: 'API Key', | ||
type: 'api-key', | ||
}, | ||
{ | ||
key: 'model', | ||
name: 'Google model', | ||
description: 'Custom model for Google API.', | ||
type: 'select', | ||
options: [ | ||
{ value: 'gemini-pro', label: 'gemini-pro' }, | ||
], | ||
default: 'gemini-pro', | ||
}, | ||
{ | ||
key: 'maxTokens', | ||
name: 'Max Tokens', | ||
description: 'The maximum number of tokens to generate in the completion.', | ||
type: 'slider', | ||
min: 0, | ||
max: 32768, | ||
default: 2048, | ||
step: 1, | ||
}, | ||
{ | ||
key: 'messageHistorySize', | ||
name: 'Max History Message Size', | ||
description: 'The number of retained historical messages will be truncated if the length of the message exceeds the MaxToken parameter.', | ||
type: 'slider', | ||
min: 1, | ||
max: 24, | ||
default: 5, | ||
step: 1, | ||
}, | ||
], | ||
bots: [ | ||
{ | ||
id: 'chat_continuous', | ||
type: 'chat_continuous', | ||
name: 'Continuous Chat', | ||
settings: [], | ||
}, | ||
{ | ||
id: 'chat_single', | ||
type: 'chat_single', | ||
name: 'Single Chat', | ||
settings: [], | ||
}, | ||
|
||
], | ||
handlePrompt, | ||
handleRapidPrompt, | ||
} | ||
return provider | ||
} | ||
|
||
export default providerGoogle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import type { Message } from '@/types/message' | ||
|
||
export const parseMessageList = (rawList: Message[]) => { | ||
interface GoogleGeminiMessage { | ||
role: 'user' | 'model' | ||
// TODO: Add support for image input | ||
parts: [ | ||
{ text: string }, | ||
] | ||
} | ||
|
||
if (rawList.length === 0) | ||
return [] | ||
|
||
const parsedList: GoogleGeminiMessage[] = [] | ||
// if first message is system message, insert an empty message after it | ||
if (rawList[0].role === 'system') { | ||
parsedList.push({ role: 'user', parts: [{ text: rawList[0].content }] }) | ||
parsedList.push({ role: 'model', parts: [{ text: 'OK.' }] }) | ||
} | ||
// covert other messages | ||
const roleDict = { | ||
user: 'user', | ||
assistant: 'model', | ||
} as const | ||
for (const message of rawList) { | ||
if (message.role === 'system') | ||
continue | ||
parsedList.push({ | ||
role: roleDict[message.role], | ||
parts: [{ text: message.content }], | ||
}) | ||
} | ||
return parsedList | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters