diff --git a/apisix/plugins/ai-drivers/openai-base.lua b/apisix/plugins/ai-drivers/openai-base.lua
index 4f279bbc3eab..959c14b85c23 100644
--- a/apisix/plugins/ai-drivers/openai-base.lua
+++ b/apisix/plugins/ai-drivers/openai-base.lua
@@ -131,10 +131,12 @@ local function read_response(conf, ctx, res, response_filter)
                     core.log.info("got token usage from ai service: ",
                                   core.json.delay_encode(data.usage))
                     ctx.llm_raw_usage = data.usage
+                    local pt = data.usage.prompt_tokens or data.usage.input_tokens or 0
+                    local ct = data.usage.completion_tokens or data.usage.output_tokens or 0
                     ctx.ai_token_usage = {
-                        prompt_tokens = data.usage.prompt_tokens or 0,
-                        completion_tokens = data.usage.completion_tokens or 0,
-                        total_tokens = data.usage.total_tokens or 0,
+                        prompt_tokens = pt,
+                        completion_tokens = ct,
+                        total_tokens = data.usage.total_tokens or (pt + ct),
                     }
                     ctx.var.llm_prompt_tokens = ctx.ai_token_usage.prompt_tokens
                     ctx.var.llm_completion_tokens = ctx.ai_token_usage.completion_tokens
@@ -188,9 +190,13 @@ local function read_response(conf, ctx, res, response_filter)
     ctx.ai_token_usage = {}
     if type(res_body.usage) == "table" then
         ctx.llm_raw_usage = res_body.usage
-        ctx.ai_token_usage.prompt_tokens = res_body.usage.prompt_tokens or 0
-        ctx.ai_token_usage.completion_tokens = res_body.usage.completion_tokens or 0
-        ctx.ai_token_usage.total_tokens = res_body.usage.total_tokens or 0
+        ctx.ai_token_usage.prompt_tokens = res_body.usage.prompt_tokens
+                                           or res_body.usage.input_tokens or 0
+        ctx.ai_token_usage.completion_tokens = res_body.usage.completion_tokens
+                                               or res_body.usage.output_tokens or 0
+        ctx.ai_token_usage.total_tokens = res_body.usage.total_tokens
+                                          or (ctx.ai_token_usage.prompt_tokens
+                                              + ctx.ai_token_usage.completion_tokens)
     end
     ctx.var.llm_prompt_tokens = ctx.ai_token_usage.prompt_tokens or 0
     ctx.var.llm_completion_tokens = ctx.ai_token_usage.completion_tokens or 0
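
Both hunks apply the same normalization: prefer the OpenAI chat-style `prompt_tokens`/`completion_tokens` fields, fall back to the `input_tokens`/`output_tokens` names, and derive `total_tokens` from the sum when the upstream omits it. A minimal standalone sketch of that behavior is shown below; `normalize_usage` and its sample payload are illustrative only and are not part of the diff or of the APISIX API.

```lua
-- Hypothetical standalone sketch of the normalization added in this diff;
-- `usage` stands in for a decoded usage object from the LLM response.
local function normalize_usage(usage)
    local pt = usage.prompt_tokens or usage.input_tokens or 0
    local ct = usage.completion_tokens or usage.output_tokens or 0
    return {
        prompt_tokens = pt,
        completion_tokens = ct,
        -- fall back to pt + ct when the upstream omits total_tokens
        total_tokens = usage.total_tokens or (pt + ct),
    }
end

-- Example: a usage object that only reports input_tokens/output_tokens
local normalized = normalize_usage({ input_tokens = 12, output_tokens = 34 })
-- normalized.prompt_tokens == 12, normalized.completion_tokens == 34,
-- normalized.total_tokens == 46
```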