Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: [AG-113] Gemini transformer plugins, incorrect URL and response error parsers #13703

Merged
merged 3 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: "**ai-proxy**: Fixed an issue where AI Transformer plugins always returned a 404 error when using 'Google One' Gemini subscriptions."
type: bugfix
scope: Plugin
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: "**ai-transformers**: Fixed a bug where the correct LLM error message was not propagated to the caller."
type: bugfix
scope: Plugin
48 changes: 39 additions & 9 deletions kong/llm/drivers/gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,19 @@ local function from_gemini_chat_openai(response, model_info, route_type)
}
end

else -- probably a server fault or other unexpected response
elseif response.candidates
and #response.candidates > 0
and response.candidates[1].finishReason
and response.candidates[1].finishReason == "SAFETY" then
local err = "transformation generation candidate breached Gemini content safety"
ngx.log(ngx.ERR, err)
return nil, err

else-- probably a server fault or other unexpected response
local err = "no generation candidates received from Gemini, or max_tokens too short"
ngx.log(ngx.ERR, err)
return nil, err

end

return cjson.encode(messages)
Expand Down Expand Up @@ -277,13 +286,34 @@ function _M.subrequest(body, conf, http_opts, return_res_table, identity_interfa
return nil, nil, "body must be table or string"
end

-- may be overridden
local url = (conf.model.options and conf.model.options.upstream_url)
or fmt(
"%s%s",
ai_shared.upstream_url_format[DRIVER_NAME],
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
)
local operation = llm_state.is_streaming_mode() and "streamGenerateContent"
or "generateContent"
local f_url = conf.model.options and conf.model.options.upstream_url

if not f_url then -- upstream_url override is not set
-- check if this is "public" or "vertex" gemini deployment
if conf.model.options
and conf.model.options.gemini
and conf.model.options.gemini.api_endpoint
and conf.model.options.gemini.project_id
and conf.model.options.gemini.location_id
then
-- vertex mode
f_url = fmt(ai_shared.upstream_url_format["gemini_vertex"],
conf.model.options.gemini.api_endpoint) ..
fmt(ai_shared.operation_map["gemini_vertex"][conf.route_type].path,
conf.model.options.gemini.project_id,
conf.model.options.gemini.location_id,
conf.model.name,
operation)
else
-- public mode
f_url = ai_shared.upstream_url_format["gemini"] ..
fmt(ai_shared.operation_map["gemini"][conf.route_type].path,
conf.model.name,
operation)
end
end

local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method

Expand Down Expand Up @@ -312,7 +342,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table, identity_interfa
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
local res, err, httpc = ai_shared.http_request(f_url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end
Expand Down
6 changes: 3 additions & 3 deletions kong/llm/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -142,18 +142,18 @@ do
if err then
return nil, err
end

-- run the shared logging/analytics/auth function
ai_shared.pre_request(self.conf, ai_request)

-- send it to the ai service
local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false, self.identity_interface)
if err then
return nil, "failed to introspect request with AI service: " .. err
end

-- parse and convert the response
local ai_response, _, err = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
local ai_response, err, _ = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
if err then
return nil, "failed to convert AI response to Kong format: " .. err
end
Expand Down
60 changes: 60 additions & 0 deletions spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,29 @@ local OPENAI_FLAT_RESPONSE = {
},
}

local GEMINI_GOOD = {
route_type = "llm/v1/chat",
logging = {
log_payloads = false,
log_statistics = true,
},
model = {
name = "gemini-1.5-flash",
provider = "gemini",
options = {
max_tokens = 512,
temperature = 0.5,
upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/failssafety",
input_cost = 10.0,
output_cost = 10.0,
},
},
auth = {
header_name = "x-goog-api-key",
header_value = "123",
},
}

local OPENAI_BAD_REQUEST = {
route_type = "llm/v1/chat",
model = {
Expand Down Expand Up @@ -177,6 +200,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
}

location = "/failssafety" {
content_by_lua_block {
local pl_file = require "pl.file"

ngx.status = 200
ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json"))
}
}

location = "/internalservererror" {
content_by_lua_block {
local pl_file = require "pl.file"
Expand Down Expand Up @@ -223,6 +255,18 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
},
}

local fails_safety = assert(bp.routes:insert {
paths = { "/echo-fails-safety" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = fails_safety.id },
config = {
prompt = SYSTEM_PROMPT,
llm = GEMINI_GOOD,
},
}

local internal_server_error = assert(bp.routes:insert {
paths = { "/echo-internal-server-error" }
})
Expand Down Expand Up @@ -327,6 +371,22 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table)
end)

it("fails Gemini content-safety", function()
local r = client:get("/echo-fails-safety", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = REQUEST_BODY,
})

local body = assert.res_status(400 , r)
local body_table, err = cjson.decode(body)

assert.is_nil(err)
assert.match_re(body_table.error.message, ".*transformation generation candidate breached Gemini content safety.*")
end)

it("internal server error from LLM", function()
local r = client:get("/echo-internal-server-error", {
headers = {
Expand Down
61 changes: 61 additions & 0 deletions spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,29 @@ local OPENAI_FLAT_RESPONSE = {
},
}

local GEMINI_GOOD = {
route_type = "llm/v1/chat",
logging = {
log_payloads = false,
log_statistics = true,
},
model = {
name = "gemini-1.5-flash",
provider = "gemini",
options = {
max_tokens = 512,
temperature = 0.5,
upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/failssafety",
input_cost = 10.0,
output_cost = 10.0,
},
},
auth = {
header_name = "x-goog-api-key",
header_value = "123",
},
}

local OPENAI_BAD_INSTRUCTIONS = {
route_type = "llm/v1/chat",
model = {
Expand Down Expand Up @@ -250,6 +273,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
}

location = "/failssafety" {
content_by_lua_block {
local pl_file = require "pl.file"

ngx.status = 200
ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json"))
}
}

location = "/internalservererror" {
content_by_lua_block {
local pl_file = require "pl.file"
Expand Down Expand Up @@ -338,6 +370,19 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
},
}

local fails_safety = assert(bp.routes:insert {
paths = { "/echo-fails-safety" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = fails_safety.id },
config = {
prompt = SYSTEM_PROMPT,
parse_llm_response_json_instructions = false,
llm = GEMINI_GOOD,
},
}

local internal_server_error = assert(bp.routes:insert {
paths = { "/echo-internal-server-error" }
})
Expand Down Expand Up @@ -485,6 +530,22 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table)
end)

it("fails Gemini content-safety", function()
local r = client:get("/echo-fails-safety", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = REQUEST_BODY,
})

local body = assert.res_status(400 , r)
local body_table, err = cjson.decode(body)

assert.is_nil(err)
assert.match_re(body_table.error.message, ".*transformation generation candidate breached Gemini content safety.*")
end)

it("internal server error from LLM", function()
local r = client:get("/echo-internal-server-error", {
headers = {
Expand Down
Loading