Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ai proxy): 3.7 regression fixes rollup #12974

Merged
merged 8 commits into from
May 6, 2024
4 changes: 2 additions & 2 deletions kong/clustering/compat/checkers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ local compatible_checkers = {
if plugin.name == 'ai-proxy' then
local config = plugin.config
if config.model and config.model.options then
if config.model.options.response_streaming then
config.model.options.response_streaming = nil
if config.response_streaming then
config.response_streaming = nil
log_warn_message('configures ' .. plugin.name .. ' plugin with' ..
' response_streaming == nil, because it is not supported' ..
' in this release',
Expand Down
7 changes: 3 additions & 4 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@ local function handle_stream_event(event_t, model_info, route_type)
and event_data.usage then
return nil, nil, {
prompt_tokens = nil,
completion_tokens = event_data.meta.usage
and event_data.meta.usage.output_tokens
completion_tokens = event_data.usage.output_tokens
tysoekong marked this conversation as resolved.
Show resolved Hide resolved
or nil,
stop_reason = event_data.delta
and event_data.delta.stop_reason
Expand Down Expand Up @@ -336,7 +335,7 @@ function _M.from_format(response_string, model_info, route_type)
return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
end

local ok, response_string, err = pcall(transform, response_string, model_info, route_type)
local ok, response_string, err, metadata = pcall(transform, response_string, model_info, route_type)
if not ok or err then
return nil, fmt("transformation failed from type %s://%s: %s",
model_info.provider,
Expand All @@ -345,7 +344,7 @@ function _M.from_format(response_string, model_info, route_type)
)
end

return response_string, nil
return response_string, nil, metadata
end

function _M.to_format(request_table, model_info, route_type)
Expand Down
8 changes: 5 additions & 3 deletions kong/llm/drivers/azure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ function _M.pre_request(conf)

-- for azure provider, all of these must/will be set by now
if conf.logging and conf.logging.log_statistics then
kong.log.set_serialize_value("ai.meta.azure_instance_id", conf.model.options.azure_instance)
kong.log.set_serialize_value("ai.meta.azure_deployment_id", conf.model.options.azure_deployment_id)
kong.log.set_serialize_value("ai.meta.azure_api_version", conf.model.options.azure_api_version)
kong.ctx.plugin.ai_extra_meta = {
["azure_instance_id"] = conf.model.options.azure_instance,
["azure_deployment_id"] = conf.model.options.azure_deployment_id,
["azure_api_version"] = conf.model.options.azure_api_version,
}
end

return true
Expand Down
2 changes: 2 additions & 0 deletions kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ local transformers_to = {
["llm/v1/chat"] = function(request_table, model_info, route_type)
request_table.model = request_table.model or model_info.name
request_table.stream = request_table.stream or false -- explicitly set this
request_table.top_k = nil -- explicitly remove unsupported default

return request_table, "application/json", nil
end,

["llm/v1/completions"] = function(request_table, model_info, route_type)
request_table.model = model_info.name
request_table.stream = request_table.stream or false -- explicitly set this
request_table.top_k = nil -- explicitly remove unsupported default

return request_table, "application/json", nil
end,
Expand Down
58 changes: 41 additions & 17 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ local log_entry_keys = {
TOKENS_CONTAINER = "usage",
META_CONTAINER = "meta",
PAYLOAD_CONTAINER = "payload",
REQUEST_BODY = "ai.payload.request",

-- payload keys
REQUEST_BODY = "request",
RESPONSE_BODY = "response",

-- meta keys
Expand Down Expand Up @@ -264,20 +264,30 @@ function _M.to_ollama(request_table, model)
end

function _M.from_ollama(response_string, model_info, route_type)
local output, _, analytics

local response_table, err = cjson.decode(response_string)
if err then
return nil, "failed to decode ollama response"
end
local output, err, _, analytics

if route_type == "stream/llm/v1/chat" then
local response_table, err = cjson.decode(response_string.data)
if err then
return nil, "failed to decode ollama response"
end

output, _, analytics = handle_stream_event(response_table, model_info, route_type)

elseif route_type == "stream/llm/v1/completions" then
local response_table, err = cjson.decode(response_string.data)
if err then
return nil, "failed to decode ollama response"
end

output, _, analytics = handle_stream_event(response_table, model_info, route_type)

else
local response_table, err = cjson.decode(response_string)
if err then
return nil, "failed to decode ollama response"
end

-- there is no direct field indicating STOP reason, so calculate it manually
local stop_length = (model_info.options and model_info.options.max_tokens) or -1
local stop_reason = "stop"
Expand Down Expand Up @@ -405,14 +415,14 @@ function _M.pre_request(conf, request_table)
request_table[auth_param_name] = auth_param_value
end

if conf.logging and conf.logging.log_statistics then
kong.log.set_serialize_value(log_entry_keys.REQUEST_MODEL, conf.model.name)
kong.log.set_serialize_value(log_entry_keys.PROVIDER_NAME, conf.model.provider)
end

-- if enabled AND request type is compatible, capture the input for analytics
if conf.logging and conf.logging.log_payloads then
kong.log.set_serialize_value(log_entry_keys.REQUEST_BODY, kong.request.get_raw_body())
local plugin_name = conf.__key__:match('plugins:(.-):')
if not plugin_name or plugin_name == "" then
return nil, "no plugin name is being passed by the plugin"
end

kong.log.set_serialize_value(fmt("ai.%s.%s.%s", plugin_name, log_entry_keys.PAYLOAD_CONTAINER, log_entry_keys.REQUEST_BODY), kong.request.get_raw_body())
end

-- log tokens prompt for reports and billing
Expand Down Expand Up @@ -468,7 +478,6 @@ function _M.post_request(conf, response_object)
if not request_analytics_plugin then
request_analytics_plugin = {
[log_entry_keys.META_CONTAINER] = {},
[log_entry_keys.PAYLOAD_CONTAINER] = {},
[log_entry_keys.TOKENS_CONTAINER] = {
[log_entry_keys.PROMPT_TOKEN] = 0,
[log_entry_keys.COMPLETION_TOKEN] = 0,
Expand All @@ -478,11 +487,18 @@ function _M.post_request(conf, response_object)
end

-- Set the model, response, and provider names in the current try context
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = conf.model.name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = kong.ctx.plugin.llm_model_requested or conf.model.name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id

-- set extra per-provider meta
if kong.ctx.plugin.ai_extra_meta and type(kong.ctx.plugin.ai_extra_meta) == "table" then
for k, v in pairs(kong.ctx.plugin.ai_extra_meta) do
request_analytics_plugin[log_entry_keys.META_CONTAINER][k] = v
end
end

-- Capture openai-format usage stats from the transformed response body
if response_object.usage then
if response_object.usage.prompt_tokens then
Expand All @@ -498,16 +514,24 @@ function _M.post_request(conf, response_object)

-- Log response body if logging payloads is enabled
if conf.logging and conf.logging.log_payloads then
request_analytics_plugin[log_entry_keys.PAYLOAD_CONTAINER][log_entry_keys.RESPONSE_BODY] = body_string
kong.log.set_serialize_value(fmt("ai.%s.%s.%s", plugin_name, log_entry_keys.PAYLOAD_CONTAINER, log_entry_keys.RESPONSE_BODY), body_string)
end

-- Update context with changed values
request_analytics_plugin[log_entry_keys.PAYLOAD_CONTAINER] = {
[log_entry_keys.RESPONSE_BODY] = body_string,
}
request_analytics[plugin_name] = request_analytics_plugin
kong.ctx.shared.analytics = request_analytics

if conf.logging and conf.logging.log_statistics then
-- Log analytics data
kong.log.set_serialize_value(fmt("%s.%s", "ai", plugin_name), request_analytics_plugin)
kong.log.set_serialize_value(fmt("ai.%s.%s", plugin_name, log_entry_keys.TOKENS_CONTAINER),
request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER])

-- Log meta
kong.log.set_serialize_value(fmt("ai.%s.%s", plugin_name, log_entry_keys.META_CONTAINER),
request_analytics_plugin[log_entry_keys.META_CONTAINER])
end

-- log tokens response for reports and billing
Expand Down
6 changes: 0 additions & 6 deletions kong/llm/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ local model_options_schema = {
type = "record",
required = false,
fields = {
{ response_streaming = {
type = "string",
description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server sent events.",
required = false,
default = "allow",
one_of = { "allow", "deny", "always" } }},
{ max_tokens = {
type = "integer",
description = "Defines the max_tokens, if using chat or completion models.",
Expand Down
23 changes: 16 additions & 7 deletions kong/plugins/ai-proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ local function handle_streaming_frame(conf)
if not event_t then
event_t, err = cjson.decode(formatted)
end

if not err then
if not token_t then
token_t = get_token_text(event_t)
Expand All @@ -126,17 +126,24 @@ local function handle_streaming_frame(conf)
(kong.ctx.plugin.ai_stream_completion_tokens or 0) + math.ceil(#strip(token_t) / 4)
end
end

elseif metadata then
kong.ctx.plugin.ai_stream_completion_tokens = metadata.completion_tokens or kong.ctx.plugin.ai_stream_completion_tokens
kong.ctx.plugin.ai_stream_prompt_tokens = metadata.prompt_tokens or kong.ctx.plugin.ai_stream_prompt_tokens
end
end

framebuffer:put("data: ")
framebuffer:put(formatted or "")
framebuffer:put((formatted ~= "[DONE]") and "\n\n" or "")
end

if conf.logging and conf.logging.log_statistics and metadata then
kong.ctx.plugin.ai_stream_completion_tokens =
(kong.ctx.plugin.ai_stream_completion_tokens or 0) +
(metadata.completion_tokens or 0)
or kong.ctx.plugin.ai_stream_completion_tokens
kong.ctx.plugin.ai_stream_prompt_tokens =
(kong.ctx.plugin.ai_stream_prompt_tokens or 0) +
(metadata.prompt_tokens or 0)
or kong.ctx.plugin.ai_stream_prompt_tokens
end
end
end

Expand Down Expand Up @@ -367,10 +374,12 @@ function _M:access(conf)
-- check if the user has asked for a stream, and/or if
-- we are forcing all requests to be of streaming type
if request_table and request_table.stream or
(conf_m.model.options and conf_m.model.options.response_streaming) == "always" then
(conf_m.response_streaming and conf_m.response_streaming == "always") then
request_table.stream = true

-- this condition will only check if user has tried
-- to activate streaming mode within their request
if conf_m.model.options and conf_m.model.options.response_streaming == "deny" then
if conf_m.response_streaming and conf_m.response_streaming == "deny" then
return bad_request("response streaming is not enabled for this LLM")
end

Expand Down
20 changes: 19 additions & 1 deletion kong/plugins/ai-proxy/schema.lua
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
local typedefs = require("kong.db.schema.typedefs")
local llm = require("kong.llm")
local deep_copy = require("kong.tools.utils").deep_copy

local this_schema = deep_copy(llm.config_schema)

local ai_proxy_only_config = {
{
response_streaming = {
type = "string",
description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server sent events.",
required = false,
default = "allow",
one_of = { "allow", "deny", "always" }},
},
}

for i, v in pairs(ai_proxy_only_config) do
this_schema.fields[#this_schema.fields+1] = v
end

return {
name = "ai-proxy",
fields = {
{ protocols = typedefs.protocols_http },
{ consumer = typedefs.no_consumer },
{ service = typedefs.no_service },
{ config = llm.config_schema },
{ config = this_schema },
},
}
4 changes: 2 additions & 2 deletions spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
name = "ai-proxy",
enabled = true,
config = {
response_streaming = "allow", -- becomes nil
route_type = "preserve", -- becomes 'llm/v1/chat'
auth = {
header_name = "header",
Expand All @@ -491,7 +492,6 @@ describe("CP/DP config compat transformations #" .. strategy, function()
options = {
max_tokens = 512,
temperature = 0.5,
response_streaming = "allow", -- becomes nil
upstream_path = "/anywhere", -- becomes nil
},
},
Expand All @@ -500,7 +500,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
-- ]]

local expected_ai_proxy_prior_37 = utils.cycle_aware_deep_copy(ai_proxy)
expected_ai_proxy_prior_37.config.model.options.response_streaming = nil
expected_ai_proxy_prior_37.config.response_streaming = nil
expected_ai_proxy_prior_37.config.model.options.upstream_path = nil
expected_ai_proxy_prior_37.config.route_type = "llm/v1/chat"

Expand Down
5 changes: 2 additions & 3 deletions spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ local _EXPECTED_CHAT_STATS = {
request_model = 'gpt-3.5-turbo',
response_model = 'gpt-3.5-turbo-0613',
},
payload = {},
usage = {
completion_token = 12,
prompt_token = 25,
Expand Down Expand Up @@ -775,8 +774,8 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.is_number(log_message.response.size)

-- test request bodies
assert.matches('"content": "What is 1 + 1?"', log_message.ai.payload.request, nil, true)
assert.matches('"role": "user"', log_message.ai.payload.request, nil, true)
assert.matches('"content": "What is 1 + 1?"', log_message.ai['ai-proxy'].payload.request, nil, true)
assert.matches('"role": "user"', log_message.ai['ai-proxy'].payload.request, nil, true)

-- test response bodies
assert.matches('"content": "The sum of 1 + 1 is 2.",', log_message.ai["ai-proxy"].payload.response, nil, true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ local _EXPECTED_CHAT_STATS = {
request_model = 'gpt-4',
response_model = 'gpt-3.5-turbo-0613',
},
payload = {},
usage = {
completion_token = 12,
prompt_token = 25,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ local _EXPECTED_CHAT_STATS = {
request_model = 'gpt-4',
response_model = 'gpt-3.5-turbo-0613',
},
payload = {},
usage = {
completion_token = 12,
prompt_token = 25,
Expand Down
Loading