Skip to content

Commit

Permalink
fix(ai): fix preserve route_type broken issue in refactor (#13576)
Browse files Browse the repository at this point in the history
* fix(ai): fix preserve route_type broken issue in refactor

(cherry picked from commit 4d7934b)
  • Loading branch information
oowl authored and github-actions[bot] committed Aug 27, 2024
1 parent 19a6877 commit 537fd9a
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 2 deletions.
2 changes: 2 additions & 0 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ function _M.configure_request(conf)
or "/"
end

ai_shared.override_upstream_url(parsed_url, conf)

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

Expand Down
3 changes: 3 additions & 0 deletions kong/llm/drivers/azure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,12 @@ function _M.configure_request(conf)
and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
or "/"
)

parsed_url = socket_url.parse(url)
end

ai_shared.override_upstream_url(parsed_url, conf)

-- if the path is read from a URL capture, 3re that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

Expand Down
4 changes: 3 additions & 1 deletion kong/llm/drivers/bedrock.lua
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table, identity_interfa
end

local parsed_url = socket_url.parse(f_url)
local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method
local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method

-- do the IAM auth and signature headers
identity_interface.interface.config.signatureVersion = "v4"
Expand Down Expand Up @@ -439,6 +439,8 @@ function _M.configure_request(conf, aws_sdk)
-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

ai_shared.override_upstream_url(parsed_url, conf)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))
Expand Down
3 changes: 3 additions & 0 deletions kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,9 @@ function _M.configure_request(conf)
or "/"
end

ai_shared.override_upstream_url(parsed_url, conf)


-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

Expand Down
3 changes: 3 additions & 0 deletions kong/llm/drivers/gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ function _M.configure_request(conf, identity_interface)
parsed_url.path = conf.model.options.upstream_path
end

ai_shared.override_upstream_url(parsed_url, conf)


-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

Expand Down
2 changes: 2 additions & 0 deletions kong/llm/drivers/llama2.lua
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ end
function _M.configure_request(conf)
local parsed_url = socket_url.parse(conf.model.options.upstream_url)

ai_shared.override_upstream_url(parsed_url, conf)

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = (parsed_url.path and string_gsub(parsed_url.path, "^/*", "/")) or "/"

Expand Down
2 changes: 2 additions & 0 deletions kong/llm/drivers/mistral.lua
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ function _M.configure_request(conf)
or "/"
end

ai_shared.override_upstream_url(parsed_url, conf)

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = (parsed_url.path and string_gsub(parsed_url.path, "^/*", "/")) or "/"

Expand Down
2 changes: 2 additions & 0 deletions kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ function _M.configure_request(conf)
or "/"
end

ai_shared.override_upstream_url(parsed_url, conf)

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

Expand Down
7 changes: 7 additions & 0 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,13 @@ function _M.calculate_cost(query_body, tokens_models, tokens_factor)
return query_cost, nil
end

function _M.override_upstream_url(parsed_url, conf)
if conf.route_type == "preserve" then
parsed_url.path = conf.model.options and conf.model.options.upstream_path
or kong.request.get_path()
end
end

-- for unit tests
_M._count_words = count_words

Expand Down
2 changes: 1 addition & 1 deletion kong/llm/proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ function _M:access(conf)
end

-- execute pre-request hooks for "all" drivers before set new body
local ok, err = ai_shared.pre_request(conf_m, parsed_request_body)
local ok, err = ai_shared.pre_request(conf_m, parsed_request_body or request_table)
if not ok then
return bail(400, err)
end
Expand Down

0 comments on commit 537fd9a

Please sign in to comment.