Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ai): fix preserve route_type broken issue in refactor #13576

Merged
merged 2 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ function _M.configure_request(conf)
or "/"
end

ai_shared.override_upstream_url(parsed_url, conf)

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

Expand Down
3 changes: 3 additions & 0 deletions kong/llm/drivers/azure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,12 @@ function _M.configure_request(conf)
and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
or "/"
)

parsed_url = socket_url.parse(url)
end

ai_shared.override_upstream_url(parsed_url, conf)

-- if the path is read from a URL capture, 3re that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

Expand Down
4 changes: 3 additions & 1 deletion kong/llm/drivers/bedrock.lua
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table, identity_interfa
end

local parsed_url = socket_url.parse(f_url)
local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method
local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method

-- do the IAM auth and signature headers
identity_interface.interface.config.signatureVersion = "v4"
Expand Down Expand Up @@ -439,6 +439,8 @@ function _M.configure_request(conf, aws_sdk)
-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

ai_shared.override_upstream_url(parsed_url, conf)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))
Expand Down
3 changes: 3 additions & 0 deletions kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,9 @@ function _M.configure_request(conf)
or "/"
end

ai_shared.override_upstream_url(parsed_url, conf)


-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

Expand Down
3 changes: 3 additions & 0 deletions kong/llm/drivers/gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ function _M.configure_request(conf, identity_interface)
parsed_url.path = conf.model.options.upstream_path
end

ai_shared.override_upstream_url(parsed_url, conf)


-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

Expand Down
2 changes: 2 additions & 0 deletions kong/llm/drivers/llama2.lua
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ end
function _M.configure_request(conf)
local parsed_url = socket_url.parse(conf.model.options.upstream_url)

ai_shared.override_upstream_url(parsed_url, conf)

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = (parsed_url.path and string_gsub(parsed_url.path, "^/*", "/")) or "/"

Expand Down
2 changes: 2 additions & 0 deletions kong/llm/drivers/mistral.lua
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ function _M.configure_request(conf)
or "/"
end

ai_shared.override_upstream_url(parsed_url, conf)

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = (parsed_url.path and string_gsub(parsed_url.path, "^/*", "/")) or "/"

Expand Down
2 changes: 2 additions & 0 deletions kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ function _M.configure_request(conf)
or "/"
end

ai_shared.override_upstream_url(parsed_url, conf)

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

Expand Down
7 changes: 7 additions & 0 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,13 @@ function _M.calculate_cost(query_body, tokens_models, tokens_factor)
return query_cost, nil
end

function _M.override_upstream_url(parsed_url, conf)
if conf.route_type == "preserve" then
parsed_url.path = conf.model.options and conf.model.options.upstream_path
or kong.request.get_path()
end
end

-- for unit tests
_M._count_words = count_words

Expand Down
2 changes: 1 addition & 1 deletion kong/llm/proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ function _M:access(conf)
end

-- execute pre-request hooks for "all" drivers before set new body
local ok, err = ai_shared.pre_request(conf_m, parsed_request_body)
local ok, err = ai_shared.pre_request(conf_m, parsed_request_body or request_table)
if not ok then
return bail(400, err)
end
Expand Down
Loading