diff --git a/changelog/unreleased/kong/ai-gemini-fix-transformer-plugins.yml b/changelog/unreleased/kong/ai-gemini-fix-transformer-plugins.yml
new file mode 100644
index 000000000000..cb82f1c92133
--- /dev/null
+++ b/changelog/unreleased/kong/ai-gemini-fix-transformer-plugins.yml
@@ -0,0 +1,3 @@
+message: "**ai-proxy**: Fixed an issue where AI Transformer plugins always returned a 404 error when using 'Google One' Gemini subscriptions."
+type: bugfix
+scope: Plugin
diff --git a/changelog/unreleased/kong/ai-transformers-bad-error-handling.yml b/changelog/unreleased/kong/ai-transformers-bad-error-handling.yml
new file mode 100644
index 000000000000..3fd09d0b0e7c
--- /dev/null
+++ b/changelog/unreleased/kong/ai-transformers-bad-error-handling.yml
@@ -0,0 +1,3 @@
+message: "**ai-transformers**: Fixed a bug where the correct LLM error message was not propagated to the caller."
+type: bugfix
+scope: Plugin
diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua
index 0de91c2f49a5..16f5b25c36f4 100644
--- a/kong/llm/drivers/gemini.lua
+++ b/kong/llm/drivers/gemini.lua
@@ -198,10 +198,19 @@ local function from_gemini_chat_openai(response, model_info, route_type)
       }
     end
 
-  else -- probably a server fault or other unexpected response
+  elseif response.candidates
+     and #response.candidates > 0
+     and response.candidates[1].finishReason
+     and response.candidates[1].finishReason == "SAFETY" then
+    local err = "transformation generation candidate breached Gemini content safety"
+    ngx.log(ngx.ERR, err)
+    return nil, err
+
+  else -- probably a server fault or other unexpected response
     local err = "no generation candidates received from Gemini, or max_tokens too short"
     ngx.log(ngx.ERR, err)
     return nil, err
+
   end
 
   return cjson.encode(messages)
@@ -277,13 +286,34 @@ function _M.subrequest(body, conf, http_opts, return_res_table, identity_interfa
     return nil, nil, "body must be table or string"
   end
 
-  -- may be overridden
-  local url = (conf.model.options and conf.model.options.upstream_url)
-    or fmt(
-      "%s%s",
-      ai_shared.upstream_url_format[DRIVER_NAME],
-      ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
-    )
+  local operation = llm_state.is_streaming_mode() and "streamGenerateContent"
+                    or "generateContent"
+  local f_url = conf.model.options and conf.model.options.upstream_url
+
+  if not f_url then -- upstream_url override is not set
+    -- check if this is "public" or "vertex" gemini deployment
+    if conf.model.options
+        and conf.model.options.gemini
+        and conf.model.options.gemini.api_endpoint
+        and conf.model.options.gemini.project_id
+        and conf.model.options.gemini.location_id
+    then
+      -- vertex mode
+      f_url = fmt(ai_shared.upstream_url_format["gemini_vertex"],
+                  conf.model.options.gemini.api_endpoint) ..
+              fmt(ai_shared.operation_map["gemini_vertex"][conf.route_type].path,
+                  conf.model.options.gemini.project_id,
+                  conf.model.options.gemini.location_id,
+                  conf.model.name,
+                  operation)
+    else
+      -- public mode
+      f_url = ai_shared.upstream_url_format["gemini"] ..
+              fmt(ai_shared.operation_map["gemini"][conf.route_type].path,
+                  conf.model.name,
+                  operation)
+    end
+  end
 
   local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method
@@ -312,7 +342,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table, identity_interfa
     headers[conf.auth.header_name] = conf.auth.header_value
   end
 
-  local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
+  local res, err, httpc = ai_shared.http_request(f_url, body_string, method, headers, http_opts, return_res_table)
   if err then
     return nil, nil, "request to ai service failed: " .. err
   end
diff --git a/kong/llm/init.lua b/kong/llm/init.lua
index 9577466a95f4..2afa28da2f0c 100644
--- a/kong/llm/init.lua
+++ b/kong/llm/init.lua
@@ -142,10 +142,10 @@ do
       if err then
        return nil, err
       end
-      
+
       -- run the shared logging/analytics/auth function
       ai_shared.pre_request(self.conf, ai_request)
-      
+
       -- send it to the ai service
       local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false, self.identity_interface)
       if err then
@@ -153,7 +153,7 @@ do
       end
 
       -- parse and convert the response
-      local ai_response, _, err = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
+      local ai_response, err, _ = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
       if err then
         return nil, "failed to convert AI response to Kong format: " .. err
       end
diff --git a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
index 40989e352246..38c62dd3dda1 100644
--- a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
+++ b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
@@ -54,6 +54,29 @@ local OPENAI_FLAT_RESPONSE = {
   },
 }
 
+local GEMINI_GOOD = {
+  route_type = "llm/v1/chat",
+  logging = {
+    log_payloads = false,
+    log_statistics = true,
+  },
+  model = {
+    name = "gemini-1.5-flash",
+    provider = "gemini",
+    options = {
+      max_tokens = 512,
+      temperature = 0.5,
+      upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/failssafety",
+      input_cost = 10.0,
+      output_cost = 10.0,
+    },
+  },
+  auth = {
+    header_name = "x-goog-api-key",
+    header_value = "123",
+  },
+}
+
 local OPENAI_BAD_REQUEST = {
   route_type = "llm/v1/chat",
   model = {
@@ -177,6 +200,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         }
       }
 
+      location = "/failssafety" {
+        content_by_lua_block {
+          local pl_file = require "pl.file"
+
+          ngx.status = 200
+          ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json"))
+        }
+      }
+
       location = "/internalservererror" {
         content_by_lua_block {
           local pl_file = require "pl.file"
@@ -223,6 +255,18 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         },
       }
 
+      local fails_safety = assert(bp.routes:insert {
+        paths = { "/echo-fails-safety" }
+      })
+      bp.plugins:insert {
+        name = PLUGIN_NAME,
+        route = { id = fails_safety.id },
+        config = {
+          prompt = SYSTEM_PROMPT,
+          llm = GEMINI_GOOD,
+        },
+      }
+
       local internal_server_error = assert(bp.routes:insert {
         paths = { "/echo-internal-server-error" }
       })
@@ -327,6 +371,22 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
           assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table)
         end)
 
+        it("fails Gemini content-safety", function()
+          local r = client:get("/echo-fails-safety", {
+            headers = {
+              ["content-type"] = "application/json",
+              ["accept"] = "application/json",
+            },
+            body = REQUEST_BODY,
+          })
+
+          local body = assert.res_status(400, r)
+          local body_table, err = cjson.decode(body)
+
+          assert.is_nil(err)
+          assert.match_re(body_table.error.message, ".*transformation generation candidate breached Gemini content safety.*")
+        end)
+
         it("internal server error from LLM", function()
           local r = client:get("/echo-internal-server-error", {
             headers = {
diff --git a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
index 287acf2dc4c0..7b06d8531592 100644
--- a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
+++ b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
@@ -71,6 +71,29 @@ local OPENAI_FLAT_RESPONSE = {
   },
 }
 
+local GEMINI_GOOD = {
+  route_type = "llm/v1/chat",
+  logging = {
+    log_payloads = false,
+    log_statistics = true,
+  },
+  model = {
+    name = "gemini-1.5-flash",
+    provider = "gemini",
+    options = {
+      max_tokens = 512,
+      temperature = 0.5,
+      upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/failssafety",
+      input_cost = 10.0,
+      output_cost = 10.0,
+    },
+  },
+  auth = {
+    header_name = "x-goog-api-key",
+    header_value = "123",
+  },
+}
+
 local OPENAI_BAD_INSTRUCTIONS = {
   route_type = "llm/v1/chat",
   model = {
@@ -250,6 +273,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         }
       }
 
+      location = "/failssafety" {
+        content_by_lua_block {
+          local pl_file = require "pl.file"
+
+          ngx.status = 200
+          ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json"))
+        }
+      }
+
       location = "/internalservererror" {
         content_by_lua_block {
           local pl_file = require "pl.file"
@@ -338,6 +370,19 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         },
       }
 
+      local fails_safety = assert(bp.routes:insert {
+        paths = { "/echo-fails-safety" }
+      })
+      bp.plugins:insert {
+        name = PLUGIN_NAME,
+        route = { id = fails_safety.id },
+        config = {
+          prompt = SYSTEM_PROMPT,
+          parse_llm_response_json_instructions = false,
+          llm = GEMINI_GOOD,
+        },
+      }
+
       local internal_server_error = assert(bp.routes:insert {
         paths = { "/echo-internal-server-error" }
       })
@@ -485,6 +530,22 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
           assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table)
         end)
 
+        it("fails Gemini content-safety", function()
+          local r = client:get("/echo-fails-safety", {
+            headers = {
+              ["content-type"] = "application/json",
+              ["accept"] = "application/json",
+            },
+            body = REQUEST_BODY,
+          })
+
+          local body = assert.res_status(400, r)
+          local body_table, err = cjson.decode(body)
+
+          assert.is_nil(err)
+          assert.match_re(body_table.error.message, ".*transformation generation candidate breached Gemini content safety.*")
+        end)
+
         it("internal server error from LLM", function()
           local r = client:get("/echo-internal-server-error", {
             headers = {
diff --git a/spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json b/spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json
new file mode 100644
index 000000000000..aa4f2d9e5ba6
--- /dev/null
+++ b/spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json
@@ -0,0 +1,30 @@
+{
+  "candidates": [
+    {
+      "finishReason": "SAFETY",
+      "index": 0,
+      "safetyRatings": [
+        {
+          "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+ "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "MEDIUM" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 319, + "totalTokenCount": 319 + } +}