Skip to content

Commit 460316c

Browse files
authored
feat: add ai-prompt-guard plugin (#12008)
1 parent 4207de5 commit 460316c

File tree

7 files changed

+659
-0
lines changed

7 files changed

+659
-0
lines changed

apisix/cli/config.lua

+1
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ local _M = {
216216
"body-transformer",
217217
"ai-prompt-template",
218218
"ai-prompt-decorator",
219+
"ai-prompt-guard",
219220
"ai-rag",
220221
"ai-aws-content-moderation",
221222
"proxy-mirror",

apisix/plugins/ai-prompt-guard.lua

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
--
2+
-- Licensed to the Apache Software Foundation (ASF) under one or more
3+
-- contributor license agreements. See the NOTICE file distributed with
4+
-- this work for additional information regarding copyright ownership.
5+
-- The ASF licenses this file to You under the Apache License, Version 2.0
6+
-- (the "License"); you may not use this file except in compliance with
7+
-- the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language governing permissions and
15+
-- limitations under the License.
16+
--
17+
local core = require("apisix.core")
18+
local ngx = ngx
19+
local ipairs = ipairs
20+
local table = table
21+
local re_compile = require("resty.core.regex").re_match_compile
22+
local re_find = ngx.re.find
23+
24+
local plugin_name = "ai-prompt-guard"
25+
26+
local schema = {
27+
type = "object",
28+
properties = {
29+
match_all_roles = {
30+
type = "boolean",
31+
default = false,
32+
},
33+
match_all_conversation_history = {
34+
type = "boolean",
35+
default = false,
36+
},
37+
allow_patterns = {
38+
type = "array",
39+
items = {type = "string"},
40+
default = {},
41+
},
42+
deny_patterns = {
43+
type = "array",
44+
items = {type = "string"},
45+
default = {},
46+
},
47+
},
48+
}
49+
50+
local _M = {
51+
version = 0.1,
52+
priority = 1072,
53+
name = plugin_name,
54+
schema = schema,
55+
}
56+
57+
function _M.check_schema(conf)
58+
local ok, err = core.schema.check(schema, conf)
59+
if not ok then
60+
return false, err
61+
end
62+
63+
-- Validate allow_patterns
64+
for _, pattern in ipairs(conf.allow_patterns) do
65+
local compiled = re_compile(pattern, "jou")
66+
if not compiled then
67+
return false, "invalid allow_pattern: " .. pattern
68+
end
69+
end
70+
71+
-- Validate deny_patterns
72+
for _, pattern in ipairs(conf.deny_patterns) do
73+
local compiled = re_compile(pattern, "jou")
74+
if not compiled then
75+
return false, "invalid deny_pattern: " .. pattern
76+
end
77+
end
78+
79+
return true
80+
end
81+
82+
local function get_content_to_check(conf, messages)
83+
if conf.match_all_conversation_history then
84+
return messages
85+
end
86+
local contents = {}
87+
if #messages > 0 then
88+
local last_msg = messages[#messages]
89+
if last_msg then
90+
core.table.insert(contents, last_msg)
91+
end
92+
end
93+
return contents
94+
end
95+
96+
function _M.access(conf, ctx)
97+
local body = core.request.get_body()
98+
if not body then
99+
core.log.error("Empty request body")
100+
return 400, {message = "Empty request body"}
101+
end
102+
103+
local json_body, err = core.json.decode(body)
104+
if err then
105+
return 400, {message = err}
106+
end
107+
108+
local messages = json_body.messages or {}
109+
messages = get_content_to_check(conf, messages)
110+
if not conf.match_all_roles then
111+
-- filter to only user messages
112+
local new_messages = {}
113+
for _, msg in ipairs(messages) do
114+
if msg.role == "user" then
115+
core.table.insert(new_messages, msg)
116+
end
117+
end
118+
messages = new_messages
119+
end
120+
if #messages == 0 then --nothing to check
121+
return 200
122+
end
123+
-- extract only messages
124+
local content = {}
125+
for _, msg in ipairs(messages) do
126+
if msg.content then
127+
core.table.insert(content, msg.content)
128+
end
129+
end
130+
local content_to_check = table.concat(content, " ")
131+
-- Allow patterns check
132+
if #conf.allow_patterns > 0 then
133+
local any_allowed = false
134+
for _, pattern in ipairs(conf.allow_patterns) do
135+
if re_find(content_to_check, pattern, "jou") then
136+
any_allowed = true
137+
break
138+
end
139+
end
140+
if not any_allowed then
141+
return 400, {message = "Request doesn't match allow patterns"}
142+
end
143+
end
144+
145+
-- Deny patterns check
146+
for _, pattern in ipairs(conf.deny_patterns) do
147+
if re_find(content_to_check, pattern, "jou") then
148+
return 400, {message = "Request contains prohibited content"}
149+
end
150+
end
151+
end
152+
153+
return _M

conf/config.yaml.example

+1
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ plugins: # plugin list (sorted by priority)
480480
- body-transformer # priority: 1080
481481
- ai-prompt-template # priority: 1071
482482
- ai-prompt-decorator # priority: 1070
483+
- ai-prompt-guard # priority: 1072
483484
- ai-rag # priority: 1060
484485
- ai-aws-content-moderation # priority: 1040 TODO: compare priority with other ai plugins
485486
- proxy-mirror # priority: 1010

docs/en/latest/config.json

+1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
"plugins/ext-plugin-post-resp",
8383
"plugins/inspect",
8484
"plugins/ocsp-stapling",
85+
"plugins/ai-prompt-guard",
8586
"plugins/ai-aws-content-moderation"
8687
]
8788
},
+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
---
2+
title: ai-prompt-guard
3+
keywords:
4+
- Apache APISIX
5+
- API Gateway
6+
- Plugin
7+
- ai-prompt-guard
8+
description: This document contains information about the Apache APISIX ai-prompt-guard Plugin.
9+
---
10+
11+
<!--
12+
#
13+
# Licensed to the Apache Software Foundation (ASF) under one or more
14+
# contributor license agreements. See the NOTICE file distributed with
15+
# this work for additional information regarding copyright ownership.
16+
# The ASF licenses this file to You under the Apache License, Version 2.0
17+
# (the "License"); you may not use this file except in compliance with
18+
# the License. You may obtain a copy of the License at
19+
#
20+
# http://www.apache.org/licenses/LICENSE-2.0
21+
#
22+
# Unless required by applicable law or agreed to in writing, software
23+
# distributed under the License is distributed on an "AS IS" BASIS,
24+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25+
# See the License for the specific language governing permissions and
26+
# limitations under the License.
27+
#
28+
-->
29+
30+
## Description
31+
32+
The `ai-prompt-guard` plugin safeguards your AI endpoints by inspecting and validating incoming prompt messages. It checks the content of requests against user-defined allowed and denied patterns to ensure that only approved inputs are processed. Based on its configuration, the plugin can either examine just the latest message or the entire conversation history, and it can be set to check prompts from all roles or only from end users.
33+
34+
When both **allow** and **deny** patterns are configured, the plugin first ensures that at least one allowed pattern is matched. If none match, the request is rejected with a _"Request doesn't match allow patterns"_ error. If an allowed pattern is found, it then checks for any occurrences of denied patterns—rejecting the request with a _"Request contains prohibited content"_ error if any are detected.
35+
36+
## Plugin Attributes
37+
38+
| **Field** | **Required** | **Type** | **Description** |
39+
| ------------------------------ | ------------ | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
40+
| match_all_roles | No | boolean | If set to `true`, the plugin will check prompt messages from all roles. Otherwise, it only validates when its role is `"user"`. Default is `false`. |
41+
| match_all_conversation_history | No | boolean | When enabled, all messages in the conversation history are concatenated and checked. If `false`, only the content of the last message is examined. Default is `false`. |
42+
| allow_patterns | No | array | A list of regex patterns. When provided, the prompt must match **at least one** pattern to be considered valid. |
43+
| deny_patterns | No | array | A list of regex patterns. If any of these patterns match the prompt content, the request is rejected. |
44+
45+
## Example usage
46+
47+
Create a route with the `ai-prompt-guard` plugin like so:
48+
49+
```shell
50+
curl "http://127.0.0.1:9180/apisix/admin/routes/1" -X PUT \
51+
-H "X-API-KEY: ${ADMIN_API_KEY}" \
52+
-d '{
53+
"uri": "/v1/chat/completions",
54+
"plugins": {
55+
"ai-prompt-guard": {
56+
"match_all_roles": true,
57+
"allow_patterns": [
58+
"goodword"
59+
],
60+
"deny_patterns": [
61+
"badword"
62+
]
63+
}
64+
},
65+
"upstream": {
66+
"type": "roundrobin",
67+
"nodes": {
68+
"api.openai.com:443": 1
69+
},
70+
"pass_host": "node",
71+
"scheme": "https"
72+
}
73+
}'
74+
```
75+
76+
Now send a request:
77+
78+
```shell
79+
curl http://127.0.0.1:9080/v1/chat/completions -i -XPOST -H 'Content-Type: application/json' -d '{
80+
"model": "gpt-4",
81+
"messages": [{ "role": "user", "content": "badword request" }]
82+
}' -H "Authorization: Bearer <your token here>"
83+
```
84+
85+
The request will fail with 400 error and following response.
86+
87+
```bash
88+
{"message":"Request doesn't match allow patterns"}
89+
```

t/admin/plugins.t

+1
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ opa
9494
authz-keycloak
9595
proxy-cache
9696
body-transformer
97+
ai-prompt-guard
9798
ai-prompt-template
9899
ai-prompt-decorator
99100
ai-rag

0 commit comments

Comments
 (0)