Skip to content

Commit

Permalink
feat(llm): add auth.allow_override option for llm auth functionality (#…
Browse files Browse the repository at this point in the history
…13493)

Add `allow_override` option to allow overriding the upstream model auth
parameter or header from the caller's request.
  • Loading branch information
oowl authored and fffonion committed Aug 15, 2024
1 parent 7ff5b94 commit 34b6ca8
Show file tree
Hide file tree
Showing 18 changed files with 1,089 additions and 18 deletions.
4 changes: 4 additions & 0 deletions changelog/unreleased/kong/ai-proxy-add-allow-override-opt.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
message: |
**AI-proxy-plugin**: Add `allow_auth_override` option to allow overriding the upstream model auth parameter or header from the caller's request.
scope: Plugin
type: feature
1 change: 1 addition & 0 deletions kong/clustering/compat/removed_fields.lua
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ return {
"model.options.bedrock",
"auth.aws_access_key_id",
"auth.aws_secret_access_key",
"auth.allow_auth_override",
"model_name_header",
},
ai_prompt_decorator = {
Expand Down
11 changes: 8 additions & 3 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -479,13 +479,18 @@ function _M.configure_request(conf)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end

if auth_param_name and auth_param_value and auth_param_location == "query" then
local query_table = kong.request.get_query()
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
end
end

-- if auth_param_location is "form", it will have already been set in a pre-request hook
Expand Down
9 changes: 7 additions & 2 deletions kong/llm/drivers/azure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,16 @@ function _M.configure_request(conf, identity_interface)

else
if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end

if auth_param_name and auth_param_value and auth_param_location == "query" then
query_table[auth_param_name] = auth_param_value
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
end
end
-- if auth_param_location is "form", it will have already been set in a pre-request hook
end
Expand Down
11 changes: 8 additions & 3 deletions kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -487,13 +487,18 @@ function _M.configure_request(conf)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end

if auth_param_name and auth_param_value and auth_param_location == "query" then
local query_table = kong.request.get_query()
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
end
end

-- if auth_param_location is "form", it will have already been set in a pre-request hook
Expand Down
11 changes: 8 additions & 3 deletions kong/llm/drivers/llama2.lua
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,18 @@ function _M.configure_request(conf)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end

if auth_param_name and auth_param_value and auth_param_location == "query" then
local query_table = kong.request.get_query()
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
end
end

-- if auth_param_location is "form", it will have already been set in a pre-request hook
Expand Down
11 changes: 8 additions & 3 deletions kong/llm/drivers/mistral.lua
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,18 @@ function _M.configure_request(conf)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end

if auth_param_name and auth_param_value and auth_param_location == "query" then
local query_table = kong.request.get_query()
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
end
end

-- if auth_param_location is "form", it will have already been set in a pre-request hook
Expand Down
11 changes: 8 additions & 3 deletions kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,18 @@ function _M.configure_request(conf)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end

if auth_param_name and auth_param_value and auth_param_location == "query" then
local query_table = kong.request.get_query()
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
end
end

-- if auth_param_location is "form", it will have already been set in a global pre-request hook
Expand Down
4 changes: 3 additions & 1 deletion kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,9 @@ function _M.pre_request(conf, request_table)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_param_name and auth_param_value and auth_param_location == "body" and request_table then
request_table[auth_param_name] = auth_param_value
if request_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
request_table[auth_param_name] = auth_param_value
end
end

-- retrieve the plugin name
Expand Down
10 changes: 10 additions & 0 deletions kong/llm/schemas/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ local auth_schema = {
required = false,
encrypted = true,
referenceable = true }},
{ allow_auth_override = {
type = "boolean",
description = "If enabled, the authorization header or parameter can be overridden in the request by the value configured in the plugin.",
required = false,
default = true }},
}
}

Expand Down Expand Up @@ -285,6 +290,11 @@ return {
{ logging = logging_schema },
},
entity_checks = {
{ conditional = { if_field = "model.provider",
if_match = { one_of = { "bedrock", "gemini" } },
then_field = "auth.allow_auth_override",
then_match = { eq = false },
then_err = "bedrock and gemini only support auth.allow_auth_override = false" }},
{ mutually_required = { "auth.header_name", "auth.header_value" }, },
{ mutually_required = { "auth.param_name", "auth.param_value", "auth.param_location" }, },

Expand Down
5 changes: 5 additions & 0 deletions spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
header_value = "value",
gcp_service_account_json = '{"service": "account"}',
gcp_use_service_account = true,
allow_auth_override = false,
},
model = {
name = "any-model-name",
Expand Down Expand Up @@ -613,6 +614,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
-- gemini fields
expected.config.auth.gcp_service_account_json = nil
expected.config.auth.gcp_use_service_account = nil
expected.config.auth.allow_auth_override = nil
expected.config.model.options.gemini = nil

-- bedrock fields
Expand Down Expand Up @@ -653,6 +655,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
header_value = "value",
gcp_service_account_json = '{"service": "account"}',
gcp_use_service_account = true,
allow_auth_override = false,
},
model = {
name = "any-model-name",
Expand Down Expand Up @@ -720,6 +723,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
header_value = "value",
gcp_service_account_json = '{"service": "account"}',
gcp_use_service_account = true,
allow_auth_override = false,
},
model = {
name = "any-model-name",
Expand Down Expand Up @@ -819,6 +823,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
-- bedrock fields
expected.config.auth.aws_access_key_id = nil
expected.config.auth.aws_secret_access_key = nil
expected.config.auth.allow_auth_override = nil
expected.config.model.options.bedrock = nil

do_assert(uuid(), "3.7.0", expected)
Expand Down
55 changes: 55 additions & 0 deletions spec/03-plugins/38-ai-proxy/00-config_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,59 @@ describe(PLUGIN_NAME .. ": (schema)", function()
assert.is_truthy(ok)
end)

it("bedrock model can not support ath.allowed_auth_override", function()
local config = {
route_type = "llm/v1/chat",
auth = {
param_name = "apikey",
param_value = "key",
param_location = "query",
header_name = "Authorization",
header_value = "Bearer token",
allow_auth_override = true,
},
model = {
name = "bedrock",
provider = "bedrock",
options = {
max_tokens = 256,
temperature = 1.0,
upstream_url = "http://nowhere",
},
},
}

local ok, err = validate(config)

assert.is_falsy(ok)
assert.is_truthy(err)
end)

it("gemini model can not support ath.allowed_auth_override", function()
local config = {
route_type = "llm/v1/chat",
auth = {
param_name = "apikey",
param_value = "key",
param_location = "query",
header_name = "Authorization",
header_value = "Bearer token",
allow_auth_override = true,
},
model = {
name = "gemini",
provider = "gemini",
options = {
max_tokens = 256,
temperature = 1.0,
upstream_url = "http://nowhere",
},
},
}

local ok, err = validate(config)

assert.is_falsy(ok)
assert.is_truthy(err)
end)
end)
Loading

0 comments on commit 34b6ca8

Please sign in to comment.