Skip to content

Commit

Permalink
refactor(router): categorize router fields to simplify logic (#11411)
Browse files Browse the repository at this point in the history
Summary

Categorize atc fields to fields header_fields query_fields,
then we can simplify the logic in select().
  • Loading branch information
chronolaw committed Feb 26, 2024
1 parent 4740f9c commit a21d385
Showing 1 changed file with 44 additions and 48 deletions.
92 changes: 44 additions & 48 deletions kong/router/atc.lua
Original file line number Diff line number Diff line change
Expand Up @@ -158,35 +158,34 @@ local function add_atc_matcher(inst, route, route_id,
end


local function is_http_headers_field(field)
return field:sub(1, 13) == "http.headers."
end

local function categorize_fields(fields)

local function has_header_matching_field(fields)
for _, field in ipairs(fields) do
if is_http_headers_field(field) then
return true
end
if not is_http then
return fields, nil, nil
end

return false
end
local basic = {}
local headers = {}
local queries = {}

-- 13 bytes, same len for "http.queries."
local PREFIX_LEN = 13 -- #"http.headers."

local function is_http_queries_field(field)
return field:sub(1, 13) == "http.queries."
end
for _, field in ipairs(fields) do
local prefix = field:sub(1, PREFIX_LEN)

if prefix == "http.headers." then
headers[field:sub(PREFIX_LEN + 1)] = field

local function has_query_matching_field(fields)
for _, field in ipairs(fields) do
if is_http_queries_field(field) then
return true
elseif prefix == "http.queries." then
queries[field:sub(PREFIX_LEN + 1)] = field

else
table.insert(basic, field)
end
end

return false
return basic, headers, queries
end


Expand Down Expand Up @@ -230,18 +229,16 @@ local function new_from_scratch(routes, get_exp_and_priority)
yield(true, phase)
end

local fields = inst:get_fields()
local match_headers = has_header_matching_field(fields)
local match_queries = has_query_matching_field(fields)
local fields, header_fields, query_fields = categorize_fields(inst:get_fields())

return setmetatable({
schema = CACHED_SCHEMA,
router = inst,
routes = routes_t,
services = services_t,
fields = fields,
match_headers = match_headers,
match_queries = match_queries,
header_fields = header_fields,
query_fields = query_fields,
updated_at = new_updated_at,
rebuilding = false,
}, _MT)
Expand Down Expand Up @@ -325,11 +322,11 @@ local function new_from_previous(routes, get_exp_and_priority, old_router)
yield(true, phase)
end

local fields = inst:get_fields()
local fields, header_fields, query_fields = categorize_fields(inst:get_fields())

old_router.fields = fields
old_router.match_headers = has_header_matching_field(fields)
old_router.match_queries = has_query_matching_field(fields)
old_router.header_fields = header_fields
old_router.query_fields = query_fields
old_router.updated_at = new_updated_at
old_router.rebuilding = false

Expand Down Expand Up @@ -455,12 +452,16 @@ function _M:select(req_method, req_uri, req_host, req_scheme,
return nil, err
end

elseif is_http_headers_field(field) then
if not req_headers then
goto continue
end
else -- unknown field
error("unknown router matching schema field: " .. field)

end -- if field

end -- for self.fields

if req_headers then
for h, field in pairs(self.header_fields) do

local h = field:sub(14)
local v = req_headers[h]

if type(v) == "string" then
Expand All @@ -478,14 +479,14 @@ function _M:select(req_method, req_uri, req_host, req_scheme,
end
end -- if type(v)

-- if v is nil or others, goto continue
-- if v is nil or others, ignore

elseif is_http_queries_field(field) then
if not req_queries then
goto continue
end
end -- for self.header_fields
end -- req_headers

if req_queries then
for n, field in pairs(self.query_fields) do

local n = field:sub(14)
local v = req_queries[n]

-- the query parameter has only one value, like /?foo=bar
Expand Down Expand Up @@ -514,15 +515,10 @@ function _M:select(req_method, req_uri, req_host, req_scheme,
end
end -- if type(v)

-- if v is nil or others, goto continue
-- if v is nil or others, ignore

else -- unknown field
error("unknown router matching schema field: " .. field)

end -- if field

::continue::
end -- for self.fields
end -- for self.query_fields
end -- req_queries

local matched = self.router:execute(c)
if not matched then
Expand Down Expand Up @@ -642,7 +638,7 @@ function _M:exec(ctx)
local sni = server_name()

local headers, headers_key
if self.match_headers then
if not is_empty_field(self.header_fields) then
headers = get_http_params(get_headers, "headers", "lua_max_req_headers")

headers["host"] = nil
Expand All @@ -651,7 +647,7 @@ function _M:exec(ctx)
end

local queries, queries_key
if self.match_queries then
if not is_empty_field(self.query_fields) then
queries = get_http_params(get_uri_args, "queries", "lua_max_uri_args")

queries_key = get_queries_key(queries)
Expand Down

0 comments on commit a21d385

Please sign in to comment.