From 063c199e5848b625b6aac0c6c6d0e7439f2b68b2 Mon Sep 17 00:00:00 2001
From: Jack Tysoe
Date: Mon, 23 Sep 2024 15:50:43 +0100
Subject: [PATCH] fix(ai-transformers): incorrect return parameter used for
 parser error handling

---
 .../ai-transformers-bad-error-handling.yml |  3 +
 kong/llm/drivers/gemini.lua                | 11 +++-
 kong/llm/init.lua                          |  6 +-
 .../02-integration_spec.lua                | 60 ++++++++++++++++++
 .../02-integration_spec.lua                | 61 +++++++++++++++++++
 5 files changed, 137 insertions(+), 4 deletions(-)
 create mode 100644 changelog/unreleased/kong/ai-transformers-bad-error-handling.yml

diff --git a/changelog/unreleased/kong/ai-transformers-bad-error-handling.yml b/changelog/unreleased/kong/ai-transformers-bad-error-handling.yml
new file mode 100644
index 000000000000..3fd09d0b0e7c
--- /dev/null
+++ b/changelog/unreleased/kong/ai-transformers-bad-error-handling.yml
@@ -0,0 +1,3 @@
+message: "**ai-transformers**: Fixed a bug where the correct LLM error message was not propagated to the caller."
+type: bugfix
+scope: Plugin
diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua
index 037a039323ca..16f5b25c36f4 100644
--- a/kong/llm/drivers/gemini.lua
+++ b/kong/llm/drivers/gemini.lua
@@ -198,10 +198,19 @@ local function from_gemini_chat_openai(response, model_info, route_type)
     }
   end
 
-  else -- probably a server fault or other unexpected response
+  elseif response.candidates
+     and #response.candidates > 0
+     and response.candidates[1].finishReason
+     and response.candidates[1].finishReason == "SAFETY" then
+    local err = "transformation generation candidate breached Gemini content safety"
+    ngx.log(ngx.ERR, err)
+    return nil, err
+
+  else-- probably a server fault or other unexpected response
     local err = "no generation candidates received from Gemini, or max_tokens too short"
     ngx.log(ngx.ERR, err)
     return nil, err
+
   end
 
   return cjson.encode(messages)
diff --git a/kong/llm/init.lua b/kong/llm/init.lua
index 9577466a95f4..2afa28da2f0c 100644
--- a/kong/llm/init.lua
+++ b/kong/llm/init.lua
@@ -142,10 +142,10 @@ do
     if err then
       return nil, err
     end
-  
+
     -- run the shared logging/analytics/auth function
     ai_shared.pre_request(self.conf, ai_request)
-  
+
     -- send it to the ai service
     local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false, self.identity_interface)
     if err then
@@ -153,7 +153,7 @@
     end
 
     -- parse and convert the response
-    local ai_response, _, err = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
+    local ai_response, err, _ = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
     if err then
       return nil, "failed to convert AI response to Kong format: " .. err
     end
diff --git a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
index 40989e352246..38c62dd3dda1 100644
--- a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
+++ b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
@@ -54,6 +54,29 @@ local OPENAI_FLAT_RESPONSE = {
   },
 }
 
+local GEMINI_GOOD = {
+  route_type = "llm/v1/chat",
+  logging = {
+    log_payloads = false,
+    log_statistics = true,
+  },
+  model = {
+    name = "gemini-1.5-flash",
+    provider = "gemini",
+    options = {
+      max_tokens = 512,
+      temperature = 0.5,
+      upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/failssafety",
+      input_cost = 10.0,
+      output_cost = 10.0,
+    },
+  },
+  auth = {
+    header_name = "x-goog-api-key",
+    header_value = "123",
+  },
+}
+
 local OPENAI_BAD_REQUEST = {
   route_type = "llm/v1/chat",
   model = {
@@ -177,6 +200,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
        }
      }
 
+      location = "/failssafety" {
+        content_by_lua_block {
+          local pl_file = require "pl.file"
+
+          ngx.status = 200
+          ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json"))
+        }
+      }
+
       location = "/internalservererror" {
         content_by_lua_block {
           local pl_file = require "pl.file"
@@ -223,6 +255,18 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         },
       }
 
+      local fails_safety = assert(bp.routes:insert {
+        paths = { "/echo-fails-safety" }
+      })
+      bp.plugins:insert {
+        name = PLUGIN_NAME,
+        route = { id = fails_safety.id },
+        config = {
+          prompt = SYSTEM_PROMPT,
+          llm = GEMINI_GOOD,
+        },
+      }
+
       local internal_server_error = assert(bp.routes:insert {
         paths = { "/echo-internal-server-error" }
       })
@@ -327,6 +371,22 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table)
       end)
 
+      it("fails Gemini content-safety", function()
+        local r = client:get("/echo-fails-safety", {
+          headers = {
+            ["content-type"] = "application/json",
+            ["accept"] = "application/json",
+          },
+          body = REQUEST_BODY,
+        })
+
+        local body = assert.res_status(400, r)
+        local body_table, err = cjson.decode(body)
+
+        assert.is_nil(err)
+        assert.match_re(body_table.error.message, ".*transformation generation candidate breached Gemini content safety.*")
+      end)
+
       it("internal server error from LLM", function()
         local r = client:get("/echo-internal-server-error", {
           headers = {
diff --git a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
index 287acf2dc4c0..7b06d8531592 100644
--- a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
+++ b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
@@ -71,6 +71,29 @@ local OPENAI_FLAT_RESPONSE = {
   },
 }
 
+local GEMINI_GOOD = {
+  route_type = "llm/v1/chat",
+  logging = {
+    log_payloads = false,
+    log_statistics = true,
+  },
+  model = {
+    name = "gemini-1.5-flash",
+    provider = "gemini",
+    options = {
+      max_tokens = 512,
+      temperature = 0.5,
+      upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/failssafety",
+      input_cost = 10.0,
+      output_cost = 10.0,
+    },
+  },
+  auth = {
+    header_name = "x-goog-api-key",
+    header_value = "123",
+  },
+}
+
 local OPENAI_BAD_INSTRUCTIONS = {
   route_type = "llm/v1/chat",
   model = {
@@ -250,6 +273,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
        }
      }
 
+      location = "/failssafety" {
+        content_by_lua_block {
+          local pl_file = require "pl.file"
+
+          ngx.status = 200
+          ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json"))
+        }
+      }
+
       location = "/internalservererror" {
         content_by_lua_block {
           local pl_file = require "pl.file"
@@ -338,6 +370,19 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         },
       }
 
+      local fails_safety = assert(bp.routes:insert {
+        paths = { "/echo-fails-safety" }
+      })
+      bp.plugins:insert {
+        name = PLUGIN_NAME,
+        route = { id = fails_safety.id },
+        config = {
+          prompt = SYSTEM_PROMPT,
+          parse_llm_response_json_instructions = false,
+          llm = GEMINI_GOOD,
+        },
+      }
+
       local internal_server_error = assert(bp.routes:insert {
         paths = { "/echo-internal-server-error" }
       })
@@ -485,6 +530,22 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table)
       end)
 
+      it("fails Gemini content-safety", function()
+        local r = client:get("/echo-fails-safety", {
+          headers = {
+            ["content-type"] = "application/json",
+            ["accept"] = "application/json",
+          },
+          body = REQUEST_BODY,
+        })
+
+        local body = assert.res_status(400, r)
+        local body_table, err = cjson.decode(body)
+
+        assert.is_nil(err)
+        assert.match_re(body_table.error.message, ".*transformation generation candidate breached Gemini content safety.*")
+      end)
+
       it("internal server error from LLM", function()
         local r = client:get("/echo-internal-server-error", {
           headers = {
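Note: the kong/llm/init.lua change above hinges on Lua's multiple-return-value convention. As the gemini.lua driver shows, from_format returns the converted body on success and nil plus an error string on failure, so the error is the second return value; the old destructuring read the third position, which is always nil, and the real message was dropped. The following is a minimal standalone sketch of that pitfall, assuming a stand-in function (from_format_like is illustrative, not Kong's real signature):

-- A minimal sketch, not part of the patch: "from_format_like" is an
-- illustrative stand-in for a driver's from_format transformer.
local function from_format_like(ok)
  if ok then
    return '{"choices":[]}'                        -- success: one return value
  end
  return nil, "no generation candidates received"  -- failure: nil plus an error
end

-- Before the fix: err is destructured from the THIRD return value,
-- which does not exist here, so the real error message is silently lost.
local res, _, err = from_format_like(false)
print(res, err)    --> nil    nil

-- After the fix: err is read from the SECOND slot, as in the patched init.lua.
local res2, err2, _ = from_format_like(false)
print(res2, err2)  --> nil    no generation candidates received

Once err is read from the correct position, the driver's message reaches the caller, which is why the new specs can assert on the exact "breached Gemini content safety" text in the 400 response body.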