fix(ai-proxy): (Bedrock)(AG-xy) fixed tools-functions calls coming back empty
tysoekong committed Oct 16, 2024
1 parent 9f6bc6b commit 4b69668
Showing 2 changed files with 120 additions and 6 deletions.
3 changes: 3 additions & 0 deletions changelog/unreleased/kong/ai-bedrock-fix-function-calling.yml
@@ -0,0 +1,3 @@
message: "**ai-proxy**: Fixed a bug where tools (function) calls to Bedrock would return empty results."
type: bugfix
scope: Plugin
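
For reference, the broken path is exercised by a request in the OpenAI chat-completions format carrying a tools array, as accepted by ai-proxy. A minimal sketch in Lua table form; the tool name and JSON schema are illustrative, not taken from the commit:

-- Minimal illustrative request; "get_weather" and its schema are made up.
local example_request = {
  messages = {
    { role = "user", content = "What's the weather in London?" },
  },
  tools = {
    {
      type = "function",
      ["function"] = {
        name = "get_weather",
        description = "Look up the current weather for a city",
        parameters = {
          type = "object",
          properties = {
            city = { type = "string" },
          },
          required = { "city" },
        },
      },
    },
  },
}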
123 changes: 117 additions & 6 deletions kong/llm/drivers/bedrock.lua
@@ -20,6 +20,13 @@ local _OPENAI_ROLE_MAPPING = {
["system"] = "assistant",
["user"] = "user",
["assistant"] = "assistant",
["tool"] = "user",
}

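-- Map Bedrock stopReason values to their OpenAI finish_reason equivalents.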
local _OPENAI_STOP_REASON_MAPPING = {
["max_tokens"] = "length",
["end_turn"] = "stop",
["tool_use"] = "tool_calls",
}

_M.bedrock_unsupported_system_role_patterns = {
@@ -51,6 +58,48 @@ local function to_tool_config(request_table)
}
end

local function to_tools(in_tools)
local out_tools

for i, v in ipairs(in_tools) do
if v['function'] then
out_tools = out_tools or {}

out_tools[i] = {
toolSpec = {
name = v['function'].name,
description = v['function'].description,
inputSchema = {
json = v['function'].parameters,
},
},
}
end
end

return out_tools
end
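-- For illustration (made-up names): an OpenAI-format tool entry such as
--   { type = "function", ["function"] = { name = "get_weather", parameters = { ... } } }
-- becomes the Bedrock Converse form
--   { toolSpec = { name = "get_weather", inputSchema = { json = { ... } } } }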

local function from_tool_call_response(tool_use)
local arguments

if tool_use['input'] and next(tool_use['input']) then
arguments = cjson.encode(tool_use['input'])
end

return {
-- set explicit numbering to ensure ordering in later modifications
[1] = {
['function'] = {
arguments = arguments,
name = tool_use.name,
},
id = tool_use.toolUseId,
type = "function",
},
}
end
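-- The single-element array returned above matches the OpenAI tool_calls shape,
-- e.g. (illustrative values):
--   { { id = "tooluse_abc123", type = "function",
--       ["function"] = { name = "get_weather", arguments = '{"city":"London"}' } } }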

local function handle_stream_event(event_t, model_info, route_type)
local new_event, metadata

@@ -113,7 +162,7 @@ local function handle_stream_event(event_t, model_info, route_type)
[1] = {
delta = {},
index = 0,
- finish_reason = body.stopReason,
+ finish_reason = _OPENAI_STOP_REASON_MAPPING[body.stopReason] or "stop",
logprobs = cjson.null,
},
},
@@ -144,7 +193,7 @@ local function handle_stream_event(event_t, model_info, route_type)
end

local function to_bedrock_chat_openai(request_table, model_info, route_type)
- if not request_table then -- try-catch type mechanism
+ if not request_table then
local err = "empty request table received for transformation"
ngx.log(ngx.ERR, "[bedrock] ", err)
return nil, nil, err
@@ -164,16 +213,60 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
if v.role and v.role == "system" then
system_prompts[#system_prompts+1] = { text = v.content }

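-- An OpenAI-format tool result message, e.g. (illustrative values)
--   { role = "tool", tool_call_id = "tooluse_abc123", content = '{"temp_c":18}' },
-- is rewrapped below as a Bedrock Converse toolResult content block.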
elseif v.role and v.role == "tool" then
local tool_execution_content, err = cjson.decode(v.content)
if err then
return nil, nil, "failed to decode function response arguments: " .. err
end

local content = {
{
toolResult = {
toolUseId = v.tool_call_id,
content = {
{
json = tool_execution_content,
},
},
status = v.status,
},
},
}

new_r.messages = new_r.messages or {}
table_insert(new_r.messages, {
role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user'
content = content,
})

else
local content
if type(v.content) == "table" then
content = v.content

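-- An assistant turn that previously issued a tool call is replayed to
-- Bedrock as a toolUse content block, rebuilt from the first OpenAI
-- tool_calls entry.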
elseif v.tool_calls and (type(v.tool_calls) == "table") then
local inputs, err = cjson.decode(v.tool_calls[1]['function'].arguments)
if err then
return nil, nil, "failed to decode function response arguments from assistant: " .. err
end

content = {
{
toolUse = {
toolUseId = v.tool_calls[1].id,
name = v.tool_calls[1]['function'].name,
input = inputs,
},
},
}

else
content = {
{
text = v.content or ""
},
}

end

-- for any other role, just construct the chat history as 'parts.text' type
Expand All @@ -199,9 +292,18 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)

new_r.inferenceConfig = to_bedrock_generation_config(request_table)

-- backwards compatibility
new_r.toolConfig = request_table.bedrock
and request_table.bedrock.toolConfig
and to_tool_config(request_table)

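-- OpenAI-format tool definitions, when present, overwrite any tools set
-- from the native bedrock.toolConfig block above.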
if request_table.tools
and type(request_table.tools) == "table"
and #request_table.tools > 0 then

new_r.toolConfig = new_r.toolConfig or {}
new_r.toolConfig.tools = to_tools(request_table.tools)
end

new_r.additionalModelRequestFields = request_table.bedrock
and request_table.bedrock.additionalModelRequestFields
@@ -219,7 +321,6 @@ local function from_bedrock_chat_openai(response, model_info, route_type)
return nil, err_client
end

- -- messages/choices table is only 1 size, so don't need to static allocate
local client_response = {}
client_response.choices = {}

@@ -229,13 +330,23 @@
and #response.output.message.content > 0
and response.output.message.content[1].text then

- client_response.choices[1] = {
local tool_use, err
if #response.output.message.content > 1 and response.output.message.content[2].toolUse then
tool_use, err = from_tool_call_response(response.output.message.content[2].toolUse)

if err then
return nil, fmt("unable to process function call response arguments: %s", err)
end
end

+ client_response.choices[1] = {
index = 0,
message = {
role = "assistant",
content = response.output.message.content[1].text,
tool_calls = tool_use,
},
- finish_reason = string_lower(response.stopReason),
+ finish_reason = _OPENAI_STOP_REASON_MAPPING[response.stopReason] or "stop",
}
client_response.object = "chat.completion"
client_response.model = model_info.name
@@ -294,7 +405,7 @@ function _M.to_format(request_table, model_info, route_type)
-- do nothing
return request_table, nil, nil
end

if not transformers_to[route_type] then
return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type)
end
