feat(ai-proxy): google-gemini support
Showing 8 changed files with 447 additions and 1 deletion.
@@ -0,0 +1,5 @@
message: |
  Kong AI Gateway (AI Proxy and associated plugin family) now supports
  the Google Gemini "chat" (generateContent) interface.
type: feature
scope: Plugin
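
In practice the driver translates an OpenAI-style llm/v1/chat request into
Gemini's generateContent payload. A rough sketch of the mapping (values are
illustrative, not taken from the diff; field names follow the public Gemini
REST API):

-- OpenAI-style input accepted by the plugin:
{ messages = { { role = "user", content = "Hi" } }, max_tokens = 256, temperature = 0.7 }
-- Gemini payload produced by the driver:
{ contents = { { role = "user", parts = { { text = "Hi" } } } },
  generationConfig = { maxOutputTokens = 256, temperature = 0.7 } }
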
@@ -0,0 +1,312 @@
local _M = {}

-- imports
local cjson = require("cjson.safe")
local fmt = string.format
local ai_shared = require("kong.llm.drivers.shared")
local socket_url = require("socket.url")
local string_gsub = string.gsub
local buffer = require("string.buffer")
local table_insert = table.insert
local string_lower = string.lower
--

-- globals
local DRIVER_NAME = "gemini"
--

local _OPENAI_ROLE_MAPPING = {
  ["system"] = "system",
  ["user"] = "user",
  ["assistant"] = "model",
}
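
-- For example, an OpenAI "assistant" turn is emitted as a Gemini "model"
-- turn. The "system" entry is present for completeness, but system messages
-- are intercepted before this mapping is consulted (see below).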

local function to_bard_generation_config(request_table)
  return {
    ["maxOutputTokens"] = request_table.max_tokens,
    ["stopSequences"] = request_table.stop,
    ["temperature"] = request_table.temperature,
    ["topK"] = request_table.top_k,
    ["topP"] = request_table.top_p,
  }
end
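
-- Illustrative mapping (hypothetical values): { max_tokens = 512, temperature = 0.2, top_p = 0.9 }
-- becomes { maxOutputTokens = 512, temperature = 0.2, topP = 0.9 }; fields the
-- caller omits stay nil and are simply dropped from the encoded JSON.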

local function to_bard_chat_openai(request_table, model_info, route_type)
  if request_table then  -- try-catch type mechanism
    local new_r = {}

    if request_table.messages and #request_table.messages > 0 then
      local system_prompt

      for i, v in ipairs(request_table.messages) do

        -- for 'system', we just concat them all into one Gemini instruction
        if v.role and v.role == "system" then
          system_prompt = system_prompt or buffer.new()
          system_prompt:put(v.content or "")
        else
          -- for any other role, just construct the chat history as 'parts.text' type
          new_r.contents = new_r.contents or {}
          table_insert(new_r.contents, {
            role = _OPENAI_ROLE_MAPPING[v.role or "user"],  -- default to 'user'
            parts = {
              {
                text = v.content or ""
              },
            },
          })
        end
      end

      ---- TODO for some reason this is broken?
      ---- I think it's something to do with which "regional" endpoint of Gemini you hit...
      -- if system_prompt then
      --   new_r.systemInstruction = {
      --     parts = {
      --       {
      --         text = system_prompt:get(),
      --       },
      --     },
      --   }
      -- end
      ----

    end

    new_r.generationConfig = to_bard_generation_config(request_table)

    kong.log.debug(cjson.encode(new_r))

    return new_r, "application/json", nil
  end

  local err = "empty request table received for transformation"
  ngx.log(ngx.ERR, err)
  return nil, nil, err
end
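
-- Illustrative input/output for to_bard_chat_openai (hypothetical values):
--   in:  { messages = { { role = "system", content = "Be terse." },
--                       { role = "user",   content = "Hello" } } }
--   out: { contents = { { role = "user", parts = { { text = "Hello" } } } },
--          generationConfig = { ... } }
-- Note the system prompt is collected but, per the TODO above, not yet forwarded.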

local function from_bard_chat_openai(response, model_info, route_type)
  local response, err = cjson.decode(response)

  if err then
    local err_client = "failed to decode response from Gemini"
    ngx.log(ngx.ERR, fmt("%s: %s", err_client, err))
    return nil, err_client
  end

  -- messages/choices table is only 1 size, so don't need to static allocate
  local messages = {}
  messages.choices = {}

  if response.candidates
      and #response.candidates > 0
      and response.candidates[1].content
      and response.candidates[1].content.parts
      and #response.candidates[1].content.parts > 0
      and response.candidates[1].content.parts[1].text then

    messages.choices[1] = {
      index = 0,
      message = {
        role = "assistant",
        content = response.candidates[1].content.parts[1].text,
      },
      finish_reason = string_lower(response.candidates[1].finishReason),
    }
    messages.object = "chat.completion"
    messages.model = model_info.name

  else  -- probably a server fault or other unexpected response
    local err = "no generation candidates received from Gemini, or max_tokens too short"
    ngx.log(ngx.ERR, err)
    return nil, err
  end

  return cjson.encode(messages)
end
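
-- Illustrative round-trip for from_bard_chat_openai (hypothetical values):
--   Gemini response in:
--     { candidates = { { content = { role = "model", parts = { { text = "Hi there" } } },
--                       finishReason = "STOP" } } }
--   OpenAI-style body out:
--     { object = "chat.completion", model = "gemini-1.5-pro",
--       choices = { { index = 0, finish_reason = "stop",
--                     message = { role = "assistant", content = "Hi there" } } } }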

local function to_bard_chat_bard(request_table, model_info, route_type)
  return nil, nil, "bard to bard not yet implemented"
end

local function from_bard_chat_bard(request_table, model_info, route_type)
  return nil, nil, "bard to bard not yet implemented"
end

local transformers_to = {
  ["llm/v1/chat"] = to_bard_chat_openai,
  ["gemini/v1/chat"] = to_bard_chat_bard,
}

local transformers_from = {
  ["llm/v1/chat"] = from_bard_chat_openai,
  ["gemini/v1/chat"] = from_bard_chat_bard,
}

function _M.from_format(response_string, model_info, route_type)
  ngx.log(ngx.DEBUG, "converting from ", model_info.provider, "://", route_type, " type to kong")

  -- MUST return a string, to set as the response body
  if not transformers_from[route_type] then
    return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
  end

  local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info, route_type)
  if not ok or err then
    return nil, fmt("transformation failed from type %s://%s: %s",
      model_info.provider,
      route_type,
      err or "unexpected_error"
    )
  end

  return response_string, nil
end

function _M.to_format(request_table, model_info, route_type)
  ngx.log(ngx.DEBUG, "converting from kong type to ", model_info.provider, "/", route_type)

  if route_type == "preserve" then
    -- do nothing
    return request_table, nil, nil
  end

  if not transformers_to[route_type] then
    return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type)
  end

  request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type)

  local ok, response_object, content_type, err = pcall(
    transformers_to[route_type],
    request_table,
    model_info,
    route_type
  )
  if err or (not ok) then
    return nil, nil, fmt("error transforming to %s://%s", model_info.provider, route_type)
  end

  return response_object, content_type, nil
end
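
-- Usage sketch (hypothetical caller, mirroring how the other drivers are
-- driven by the plugin):
--   local body, content_type, err = _M.to_format(request_table, conf.model, "llm/v1/chat")
--   -- ...send `body` upstream, read the raw response...
--   local translated, err = _M.from_format(raw_response, conf.model, "llm/v1/chat")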

function _M.subrequest(body, conf, http_opts, return_res_table)
  -- use shared/standard subrequest routine
  local body_string, err

  if type(body) == "table" then
    body_string, err = cjson.encode(body)
    if err then
      return nil, nil, "failed to encode body to json: " .. err
    end
  elseif type(body) == "string" then
    body_string = body
  else
    return nil, nil, "body must be table or string"
  end

  -- may be overridden
  local url = (conf.model.options and conf.model.options.upstream_url)
    or fmt(
      "%s%s",
      ai_shared.upstream_url_format[DRIVER_NAME],
      ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
    )

  local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method

  local headers = {
    ["Accept"] = "application/json",
    ["Content-Type"] = "application/json",
  }

  if conf.auth and conf.auth.header_name then
    headers[conf.auth.header_name] = conf.auth.header_value
  end

  local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
  if err then
    return nil, nil, "request to ai service failed: " .. err
  end

  if return_res_table then
    return res, res.status, nil, httpc
  else
    -- At this point, the entire request / response is complete and the connection
    -- will be closed or back on the connection pool.
    local status = res.status
    local body = res.body

    if status > 299 then
      return body, res.status, "status code " .. status
    end

    return body, res.status, nil
  end
end
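
-- Example (hypothetical conf): _M.subrequest({ messages = ... }, conf, {}, false)
-- returns body, status, err; with return_res_table = true it instead returns
-- res, res.status, nil, httpc so the caller can read the body itself.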

function _M.header_filter_hooks(body)
  -- nothing to parse in header_filter phase
end

function _M.post_request(conf)
  if ai_shared.clear_response_headers[DRIVER_NAME] then
    for i, v in ipairs(ai_shared.clear_response_headers[DRIVER_NAME]) do
      kong.response.clear_header(v)
    end
  end
end

function _M.pre_request(conf, body)
  kong.service.request.set_header("Accept-Encoding", "gzip, identity")  -- tell server not to send brotli

  return true, nil
end

-- returns true on success, or nil and an error message
function _M.configure_request(conf)
  local parsed_url

  if conf.model.options and conf.model.options.upstream_url then
    parsed_url = socket_url.parse(conf.model.options.upstream_url)
  else
    local path = conf.model.options
                 and conf.model.options.upstream_path
                 or ai_shared.operation_map[DRIVER_NAME][conf.route_type]
                 and fmt(ai_shared.operation_map[DRIVER_NAME][conf.route_type].path, conf.model.name)
                 or "/"
    if not path then
      return nil, fmt("operation %s is not supported for gemini provider", conf.route_type)
    end

    parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
    parsed_url.path = path
  end

  -- if the path is read from a URL capture, ensure that it is valid
  parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

  kong.service.request.set_path(parsed_url.path)
  kong.service.request.set_scheme(parsed_url.scheme)
  kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))

  local auth_header_name = conf.auth and conf.auth.header_name
  local auth_header_value = conf.auth and conf.auth.header_value
  local auth_param_name = conf.auth and conf.auth.param_name
  local auth_param_value = conf.auth and conf.auth.param_value
  local auth_param_location = conf.auth and conf.auth.param_location

  if auth_header_name and auth_header_value then
    kong.service.request.set_header(auth_header_name, auth_header_value)
  end

  if auth_param_name and auth_param_value and auth_param_location == "query" then
    local query_table = kong.request.get_query()
    query_table[auth_param_name] = auth_param_value
    kong.service.request.set_query(query_table)
  end

  -- if auth_param_location is "form", it will have already been set in a global pre-request hook
  return true, nil
end
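
-- Illustrative auth setup for Gemini's public API, which expects the API key
-- as a "key" query parameter (the value here is a placeholder):
--   conf.auth = { param_name = "key", param_value = "<API_KEY>", param_location = "query" }
-- configure_request then merges { key = "<API_KEY>" } into the upstream query string.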

return _M