From 908ad6578cfcc6e38fe1c67a3e218e3dce0c3eee Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Fri, 9 Aug 2024 16:07:17 +0100 Subject: [PATCH] fix(ai-proxy): cloud identity (sdk) now used in ai transformer plugins --- ...oxy-cloud-identity-transformer-plugins.yml | 5 ++ kong/llm/drivers/azure.lua | 2 +- kong/llm/drivers/bedrock.lua | 51 ++++++++++++++----- kong/llm/drivers/gemini.lua | 21 +++++++- kong/llm/drivers/shared.lua | 3 -- kong/llm/init.lua | 10 ++-- .../ai-request-transformer/handler.lua | 19 ++++++- .../ai-response-transformer/handler.lua | 19 ++++++- 8 files changed, 106 insertions(+), 24 deletions(-) create mode 100644 changelog/unreleased/kong/ai-proxy-cloud-identity-transformer-plugins.yml diff --git a/changelog/unreleased/kong/ai-proxy-cloud-identity-transformer-plugins.yml b/changelog/unreleased/kong/ai-proxy-cloud-identity-transformer-plugins.yml new file mode 100644 index 000000000000..1058206319a2 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-cloud-identity-transformer-plugins.yml @@ -0,0 +1,5 @@ +message: | + **AI-Transformer-Plugins**: Fixed a bug where cloud identity authentication + was not used in `ai-request-transformer` and `ai-response-transformer` plugins. +scope: Plugin +type: bugfix diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index 343904ffad24..8fe55b2faed6 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -40,7 +40,7 @@ function _M.post_request(conf) end end -function _M.subrequest(body, conf, http_opts, return_res_table) +function _M.subrequest(body, conf, http_opts, return_res_table, identity_interface) local body_string, err if type(body) == "table" then diff --git a/kong/llm/drivers/bedrock.lua b/kong/llm/drivers/bedrock.lua index 5f7ddce5119c..26631107c6ae 100644 --- a/kong/llm/drivers/bedrock.lua +++ b/kong/llm/drivers/bedrock.lua @@ -306,7 +306,7 @@ function _M.to_format(request_table, model_info, route_type) return response_object, content_type, nil end -function _M.subrequest(body, conf, http_opts, return_res_table) +function _M.subrequest(body, conf, http_opts, return_res_table, identity_interface) -- use shared/standard subrequest routine local body_string, err @@ -322,25 +322,52 @@ function _M.subrequest(body, conf, http_opts, return_res_table) end -- may be overridden - local url = (conf.model.options and conf.model.options.upstream_url) - or fmt( - "%s%s", - ai_shared.upstream_url_format[DRIVER_NAME], - ai_shared.operation_map[DRIVER_NAME][conf.route_type].path - ) + local f_url = conf.model.options and conf.model.options.upstream_url + + if not f_url then -- upstream_url override is not set + local uri = fmt(ai_shared.upstream_url_format[DRIVER_NAME], identity_interface.interface.config.region) + local path = fmt( + ai_shared.operation_map[DRIVER_NAME][conf.route_type].path, + conf.model.name, + "converse") - local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method + f_url = fmt("%s%s", uri, path) + end + + local parsed_url = socket_url.parse(f_url) + local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method + + -- do the IAM auth and signature headers + identity_interface.interface.config.signatureVersion = "v4" + identity_interface.interface.config.endpointPrefix = "bedrock" + + local r = { + headers = {}, + method = method, + path = parsed_url.path, + host = parsed_url.host, + port = tonumber(parsed_url.port) or 443, + body = cjson.encode(body), + } + + local signature, err = signer(identity_interface.interface.config, r) + if not signature then + return nil, "failed to sign AWS request: " .. (err or "NONE") + end local headers = { ["Accept"] = "application/json", ["Content-Type"] = "application/json", } - - if conf.auth and conf.auth.header_name then - headers[conf.auth.header_name] = conf.auth.header_value + headers["Authorization"] = signature.headers["Authorization"] + if signature.headers["X-Amz-Security-Token"] then + headers["X-Amz-Security-Token"] = signature.headers["X-Amz-Security-Token"] + end + if signature.headers["X-Amz-Date"] then + headers["X-Amz-Date"] = signature.headers["X-Amz-Date"] end - local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table) + local res, err, httpc = ai_shared.http_request(f_url, body_string, method, headers, http_opts, return_res_table) if err then return nil, nil, "request to ai service failed: " .. err end diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index d386961997f4..abe71a74775a 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -262,7 +262,7 @@ function _M.to_format(request_table, model_info, route_type) return response_object, content_type, nil end -function _M.subrequest(body, conf, http_opts, return_res_table) +function _M.subrequest(body, conf, http_opts, return_res_table, identity_interface) -- use shared/standard subrequest routine local body_string, err @@ -292,7 +292,24 @@ function _M.subrequest(body, conf, http_opts, return_res_table) ["Content-Type"] = "application/json", } - if conf.auth and conf.auth.header_name then + if identity_interface and identity_interface.interface then + if identity_interface.interface:needsRefresh() then + -- HACK: A bug in lua-resty-gcp tries to re-load the environment + -- variable every time, which fails in nginx + -- Create a whole new interface instead. + -- Memory leaks are mega unlikely because this should only + -- happen about once an hour, and the old one will be + -- cleaned up anyway. + local service_account_json = identity_interface.interface.service_account_json + local identity_interface_new = identity_interface.interface:new(service_account_json) + identity_interface.interface.token = identity_interface_new.token + + kong.log.notice("gcp identity token for ", kong.plugin.get_id(), " has been refreshed") + end + + headers["Authorization"] = "Bearer " .. identity_interface.interface.token + + elseif conf.auth and conf.auth.header_name then headers[conf.auth.header_name] = conf.auth.header_value end diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index b96734db272e..faa14ef3c2a1 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -70,9 +70,6 @@ local AWS = require("resty.aws") local AWS_REGION do AWS_REGION = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") end - -local AZURE_TOKEN_SCOPE = "https://cognitiveservices.azure.com/.default" -local AZURE_TOKEN_VERSION = "v2.0" ---- _M._CONST = { diff --git a/kong/llm/init.lua b/kong/llm/init.lua index b4b7bba5ae7a..302d74cf144d 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -95,9 +95,9 @@ do local ai_request -- mistral, cohere, titan (via Bedrock) don't support system commands - if self.driver == "bedrock" then + if self.conf.model.provider == "bedrock" then for _, p in ipairs(self.driver.bedrock_unsupported_system_role_patterns) do - if request.model:find(p) then + if self.conf.model.name:find(p) then ai_request = { messages = { [1] = { @@ -147,7 +147,7 @@ do ai_shared.pre_request(self.conf, ai_request) -- send it to the ai service - local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false) + local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false, self.identity_interface) if err then return nil, "failed to introspect request with AI service: " .. err end @@ -225,13 +225,15 @@ do --- Instantiate a new LLM driver instance. -- @tparam table conf Configuration table -- @tparam table http_opts HTTP options table + -- @tparam table [optional] cloud-authentication identity interface -- @treturn[1] table A new LLM driver instance -- @treturn[2] nil -- @treturn[2] string An error message if instantiation failed - function _M.new_driver(conf, http_opts) + function _M.new_driver(conf, http_opts, identity_interface) local self = { conf = conf or {}, http_opts = http_opts or {}, + identity_interface = identity_interface, -- 'or nil' } setmetatable(self, LLM) diff --git a/kong/plugins/ai-request-transformer/handler.lua b/kong/plugins/ai-request-transformer/handler.lua index 1bad3a92db3d..dd4325183d45 100644 --- a/kong/plugins/ai-request-transformer/handler.lua +++ b/kong/plugins/ai-request-transformer/handler.lua @@ -5,11 +5,17 @@ local kong_meta = require "kong.meta" local fmt = string.format local llm = require("kong.llm") local llm_state = require("kong.llm.state") +local ai_shared = require("kong.llm.drivers.shared") -- _M.PRIORITY = 777 _M.VERSION = kong_meta.version +local _KEYBASTION = setmetatable({}, { + __mode = "k", + __index = ai_shared.cloud_identity_function, +}) + local function bad_request(msg) kong.log.info(msg) return kong.response.exit(400, { error = { message = msg } }) @@ -40,14 +46,25 @@ local function create_http_opts(conf) end function _M:access(conf) + local kong_ctx_shared = kong.ctx.shared + kong.service.request.enable_buffering() llm_state.should_disable_ai_proxy_response_transform() + -- get cloud identity SDK, if required + local identity_interface = _KEYBASTION[conf.llm] + + if identity_interface and identity_interface.error then + kong_ctx_shared.skip_response_transformer = true + kong.log.err("error authenticating with ", conf.model.provider, " using native provider auth, ", identity_interface.error) + return kong.response.exit(500, "LLM request failed before proxying") + end + -- first find the configured LLM interface and driver local http_opts = create_http_opts(conf) conf.llm.__plugin_id = conf.__plugin_id conf.llm.__key__ = conf.__key__ - local ai_driver, err = llm.new_driver(conf.llm, http_opts) + local ai_driver, err = llm.new_driver(conf.llm, http_opts, identity_interface) if not ai_driver then return internal_server_error(err) diff --git a/kong/plugins/ai-response-transformer/handler.lua b/kong/plugins/ai-response-transformer/handler.lua index 815b64f351fa..872b8ea924f4 100644 --- a/kong/plugins/ai-response-transformer/handler.lua +++ b/kong/plugins/ai-response-transformer/handler.lua @@ -7,11 +7,17 @@ local fmt = string.format local kong_utils = require("kong.tools.gzip") local llm = require("kong.llm") local llm_state = require("kong.llm.state") +local ai_shared = require("kong.llm.drivers.shared") -- _M.PRIORITY = 769 _M.VERSION = kong_meta.version +local _KEYBASTION = setmetatable({}, { + __mode = "k", + __index = ai_shared.cloud_identity_function, +}) + local function bad_request(msg) kong.log.info(msg) return kong.response.exit(400, { error = { message = msg } }) @@ -99,14 +105,25 @@ end function _M:access(conf) + local kong_ctx_shared = kong.ctx.shared + kong.service.request.enable_buffering() llm_state.disable_ai_proxy_response_transform() + -- get cloud identity SDK, if required + local identity_interface = _KEYBASTION[conf.llm] + + if identity_interface and identity_interface.error then + kong_ctx_shared.skip_response_transformer = true + kong.log.err("error authenticating with ", conf.model.provider, " using native provider auth, ", identity_interface.error) + return kong.response.exit(500, "LLM request failed before proxying") + end + -- first find the configured LLM interface and driver local http_opts = create_http_opts(conf) conf.llm.__plugin_id = conf.__plugin_id conf.llm.__key__ = conf.__key__ - local ai_driver, err = llm.new_driver(conf.llm, http_opts) + local ai_driver, err = llm.new_driver(conf.llm, http_opts, identity_interface) if not ai_driver then return internal_server_error(err)